feat(agent): implement Agent audit module with LangGraph integration
- Introduce new Agent audit functionality for autonomous code security analysis and vulnerability verification.
- Add API endpoints for managing Agent tasks and configurations.
- Implement UI components for Agent mode selection and embedding model configuration.
- Enhance the overall architecture with a focus on RAG (Retrieval-Augmented Generation) for improved code semantic search.
- Create a sandbox environment for secure execution of vulnerability tests.
- Update documentation to include details on the new Agent audit features and usage instructions.
parent 7c9b9ea933
commit 9bc114af1f
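For orientation, here is a minimal client sketch against the endpoints this commit adds. It is an illustration only: the /api/v1 mount point, host, and bearer token are assumptions and are not part of this diff.

import requests

BASE = "http://localhost:8000/api/v1"          # assumed mount point of the API router
HEADERS = {"Authorization": "Bearer <token>"}  # assumed auth scheme

# Create and start an Agent audit task for a project.
task = requests.post(
    f"{BASE}/agent-tasks/",
    json={"project_id": "<project-id>", "verification_level": "sandbox"},
    headers=HEADERS,
).json()

# Check progress, then pull the findings once the task completes.
status = requests.get(f"{BASE}/agent-tasks/{task['id']}", headers=HEADERS).json()["status"]
findings = requests.get(f"{BASE}/agent-tasks/{task['id']}/findings", headers=HEADERS).json()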
@@ -58,3 +58,4 @@ CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
@@ -102,3 +102,4 @@ format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
@@ -89,3 +89,4 @@ else:
    asyncio.run(run_migrations_online())
@@ -24,3 +24,4 @@ def downgrade() -> None:
    ${downgrades if downgrades else "pass"}
@@ -0,0 +1,255 @@
"""Add agent audit tables

Revision ID: 006_add_agent_tables
Revises: 5fc1cc05d5d0
Create Date: 2024-01-15 10:00:00.000000

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = '006_add_agent_tables'
down_revision = '5fc1cc05d5d0'
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Create the agent_tasks table
    op.create_table(
        'agent_tasks',
        sa.Column('id', sa.String(36), primary_key=True),
        sa.Column('project_id', sa.String(36), sa.ForeignKey('projects.id', ondelete='CASCADE'), nullable=False),

        # Basic task info
        sa.Column('name', sa.String(255), nullable=True),
        sa.Column('description', sa.Text(), nullable=True),
        sa.Column('task_type', sa.String(50), default='agent_audit'),

        # Task configuration
        sa.Column('audit_scope', sa.JSON(), nullable=True),
        sa.Column('target_vulnerabilities', sa.JSON(), nullable=True),
        sa.Column('verification_level', sa.String(50), default='sandbox'),

        # Branch info
        sa.Column('branch_name', sa.String(255), nullable=True),

        # Exclude patterns
        sa.Column('exclude_patterns', sa.JSON(), nullable=True),

        # File scope
        sa.Column('target_files', sa.JSON(), nullable=True),

        # LLM configuration
        sa.Column('llm_config', sa.JSON(), nullable=True),

        # Agent configuration
        sa.Column('agent_config', sa.JSON(), nullable=True),
        sa.Column('max_iterations', sa.Integer(), default=50),
        sa.Column('token_budget', sa.Integer(), default=100000),
        sa.Column('timeout_seconds', sa.Integer(), default=1800),

        # Status
        sa.Column('status', sa.String(20), default='pending'),
        sa.Column('current_phase', sa.String(50), nullable=True),
        sa.Column('current_step', sa.String(255), nullable=True),
        sa.Column('error_message', sa.Text(), nullable=True),

        # Progress statistics
        sa.Column('total_files', sa.Integer(), default=0),
        sa.Column('indexed_files', sa.Integer(), default=0),
        sa.Column('analyzed_files', sa.Integer(), default=0),
        sa.Column('total_chunks', sa.Integer(), default=0),

        # Agent statistics
        sa.Column('total_iterations', sa.Integer(), default=0),
        sa.Column('tool_calls_count', sa.Integer(), default=0),
        sa.Column('tokens_used', sa.Integer(), default=0),

        # Finding statistics
        sa.Column('findings_count', sa.Integer(), default=0),
        sa.Column('verified_count', sa.Integer(), default=0),
        sa.Column('false_positive_count', sa.Integer(), default=0),

        # Severity statistics
        sa.Column('critical_count', sa.Integer(), default=0),
        sa.Column('high_count', sa.Integer(), default=0),
        sa.Column('medium_count', sa.Integer(), default=0),
        sa.Column('low_count', sa.Integer(), default=0),

        # Quality scores
        sa.Column('quality_score', sa.Float(), default=0.0),
        sa.Column('security_score', sa.Float(), default=0.0),

        # Audit plan
        sa.Column('audit_plan', sa.JSON(), nullable=True),

        # Timestamps
        sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now()),
        sa.Column('updated_at', sa.DateTime(timezone=True), onupdate=sa.func.now()),
        sa.Column('started_at', sa.DateTime(timezone=True), nullable=True),
        sa.Column('completed_at', sa.DateTime(timezone=True), nullable=True),

        # Creator
        sa.Column('created_by', sa.String(36), sa.ForeignKey('users.id'), nullable=False),
    )

    # Create agent_tasks indexes
    op.create_index('ix_agent_tasks_project_id', 'agent_tasks', ['project_id'])
    op.create_index('ix_agent_tasks_status', 'agent_tasks', ['status'])
    op.create_index('ix_agent_tasks_created_by', 'agent_tasks', ['created_by'])
    op.create_index('ix_agent_tasks_created_at', 'agent_tasks', ['created_at'])

    # Create the agent_events table
    op.create_table(
        'agent_events',
        sa.Column('id', sa.String(36), primary_key=True),
        sa.Column('task_id', sa.String(36), sa.ForeignKey('agent_tasks.id', ondelete='CASCADE'), nullable=False),

        # Event info
        sa.Column('event_type', sa.String(50), nullable=False),
        sa.Column('phase', sa.String(50), nullable=True),

        # Event content
        sa.Column('message', sa.Text(), nullable=True),

        # Tool call details
        sa.Column('tool_name', sa.String(100), nullable=True),
        sa.Column('tool_input', sa.JSON(), nullable=True),
        sa.Column('tool_output', sa.JSON(), nullable=True),
        sa.Column('tool_duration_ms', sa.Integer(), nullable=True),

        # Associated finding
        sa.Column('finding_id', sa.String(36), nullable=True),

        # Token usage
        sa.Column('tokens_used', sa.Integer(), default=0),

        # Metadata
        sa.Column('metadata', sa.JSON(), nullable=True),

        # Sequence number
        sa.Column('sequence', sa.Integer(), default=0),

        # Timestamp
        sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now()),
    )

    # Create agent_events indexes
    op.create_index('ix_agent_events_task_id', 'agent_events', ['task_id'])
    op.create_index('ix_agent_events_event_type', 'agent_events', ['event_type'])
    op.create_index('ix_agent_events_sequence', 'agent_events', ['sequence'])
    op.create_index('ix_agent_events_created_at', 'agent_events', ['created_at'])

    # Create the agent_findings table
    op.create_table(
        'agent_findings',
        sa.Column('id', sa.String(36), primary_key=True),
        sa.Column('task_id', sa.String(36), sa.ForeignKey('agent_tasks.id', ondelete='CASCADE'), nullable=False),

        # Basic vulnerability info
        sa.Column('vulnerability_type', sa.String(100), nullable=False),
        sa.Column('severity', sa.String(20), nullable=False),
        sa.Column('title', sa.String(500), nullable=False),
        sa.Column('description', sa.Text(), nullable=True),

        # Location info
        sa.Column('file_path', sa.String(500), nullable=True),
        sa.Column('line_start', sa.Integer(), nullable=True),
        sa.Column('line_end', sa.Integer(), nullable=True),
        sa.Column('column_start', sa.Integer(), nullable=True),
        sa.Column('column_end', sa.Integer(), nullable=True),
        sa.Column('function_name', sa.String(255), nullable=True),
        sa.Column('class_name', sa.String(255), nullable=True),

        # Code snippets
        sa.Column('code_snippet', sa.Text(), nullable=True),
        sa.Column('code_context', sa.Text(), nullable=True),

        # Data-flow info
        sa.Column('source', sa.Text(), nullable=True),
        sa.Column('sink', sa.Text(), nullable=True),
        sa.Column('dataflow_path', sa.JSON(), nullable=True),

        # Verification info
        sa.Column('status', sa.String(30), default='new'),
        sa.Column('is_verified', sa.Boolean(), default=False),
        sa.Column('verification_method', sa.Text(), nullable=True),
        sa.Column('verification_result', sa.JSON(), nullable=True),
        sa.Column('verified_at', sa.DateTime(timezone=True), nullable=True),

        # PoC
        sa.Column('has_poc', sa.Boolean(), default=False),
        sa.Column('poc_code', sa.Text(), nullable=True),
        sa.Column('poc_description', sa.Text(), nullable=True),
        sa.Column('poc_steps', sa.JSON(), nullable=True),

        # Fix suggestions
        sa.Column('suggestion', sa.Text(), nullable=True),
        sa.Column('fix_code', sa.Text(), nullable=True),
        sa.Column('fix_description', sa.Text(), nullable=True),
        sa.Column('references', sa.JSON(), nullable=True),

        # AI explanation
        sa.Column('ai_explanation', sa.Text(), nullable=True),
        sa.Column('ai_confidence', sa.Float(), nullable=True),

        # XAI
        sa.Column('xai_what', sa.Text(), nullable=True),
        sa.Column('xai_why', sa.Text(), nullable=True),
        sa.Column('xai_how', sa.Text(), nullable=True),
        sa.Column('xai_impact', sa.Text(), nullable=True),

        # Matched rule
        sa.Column('matched_rule_code', sa.String(100), nullable=True),
        sa.Column('matched_pattern', sa.Text(), nullable=True),

        # CVSS score
        sa.Column('cvss_score', sa.Float(), nullable=True),
        sa.Column('cvss_vector', sa.String(100), nullable=True),

        # Metadata
        sa.Column('metadata', sa.JSON(), nullable=True),
        sa.Column('tags', sa.JSON(), nullable=True),

        # Deduplication fingerprint
        sa.Column('fingerprint', sa.String(64), nullable=True),

        # Timestamps
        sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now()),
        sa.Column('updated_at', sa.DateTime(timezone=True), onupdate=sa.func.now()),
    )

    # Create agent_findings indexes
    op.create_index('ix_agent_findings_task_id', 'agent_findings', ['task_id'])
    op.create_index('ix_agent_findings_vulnerability_type', 'agent_findings', ['vulnerability_type'])
    op.create_index('ix_agent_findings_severity', 'agent_findings', ['severity'])
    op.create_index('ix_agent_findings_file_path', 'agent_findings', ['file_path'])
    op.create_index('ix_agent_findings_status', 'agent_findings', ['status'])
    op.create_index('ix_agent_findings_fingerprint', 'agent_findings', ['fingerprint'])


def downgrade() -> None:
    # Drop indexes and tables
    op.drop_index('ix_agent_findings_fingerprint', 'agent_findings')
    op.drop_index('ix_agent_findings_status', 'agent_findings')
    op.drop_index('ix_agent_findings_file_path', 'agent_findings')
    op.drop_index('ix_agent_findings_severity', 'agent_findings')
    op.drop_index('ix_agent_findings_vulnerability_type', 'agent_findings')
    op.drop_index('ix_agent_findings_task_id', 'agent_findings')
    op.drop_table('agent_findings')

    op.drop_index('ix_agent_events_created_at', 'agent_events')
    op.drop_index('ix_agent_events_sequence', 'agent_events')
    op.drop_index('ix_agent_events_event_type', 'agent_events')
    op.drop_index('ix_agent_events_task_id', 'agent_events')
    op.drop_table('agent_events')

    op.drop_index('ix_agent_tasks_created_at', 'agent_tasks')
    op.drop_index('ix_agent_tasks_created_by', 'agent_tasks')
    op.drop_index('ix_agent_tasks_status', 'agent_tasks')
    op.drop_index('ix_agent_tasks_project_id', 'agent_tasks')
    op.drop_table('agent_tasks')
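The three new tables form a simple hierarchy: agent_events and agent_findings both reference agent_tasks through task_id with cascade delete. A hedged query sketch against the ORM models introduced later in this commit:

from sqlalchemy import select
from app.models.agent_task import AgentFinding

async def verified_findings(session, task_id: str):
    # All verified findings recorded for a single Agent task.
    stmt = select(AgentFinding).where(
        AgentFinding.task_id == task_id,
        AgentFinding.is_verified == True,  # noqa: E712 - SQLAlchemy needs the explicit comparison
    )
    return (await session.execute(stmt)).scalars().all()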
@@ -1,5 +1,5 @@
 from fastapi import APIRouter
-from app.api.v1.endpoints import auth, users, projects, tasks, scan, members, config, database, prompts, rules
+from app.api.v1.endpoints import auth, users, projects, tasks, scan, members, config, database, prompts, rules, agent_tasks, embedding_config
 
 api_router = APIRouter()
 api_router.include_router(auth.router, prefix="/auth", tags=["auth"])
@@ -12,3 +12,5 @@ api_router.include_router(config.router, prefix="/config", tags=["config"])
 api_router.include_router(database.router, prefix="/database", tags=["database"])
 api_router.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
 api_router.include_router(rules.router, prefix="/rules", tags=["rules"])
+api_router.include_router(agent_tasks.router, prefix="/agent-tasks", tags=["agent-tasks"])
+api_router.include_router(embedding_config.router, prefix="/embedding", tags=["embedding"])
@@ -0,0 +1,639 @@
"""
DeepAudit Agent audit task API
LangGraph-based Agent auditing
"""

import asyncio
import json
import logging
import os
import shutil
from typing import Any, List, Optional, Dict
from datetime import datetime, timezone
from uuid import uuid4

from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks, Query
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from pydantic import BaseModel, Field

from app.api import deps
from app.db.session import get_db, async_session_factory
from app.models.agent_task import (
    AgentTask, AgentEvent, AgentFinding,
    AgentTaskStatus, AgentTaskPhase, AgentEventType,
    VulnerabilitySeverity, FindingStatus,
)
from app.models.project import Project
from app.models.user import User
from app.services.agent import AgentRunner, EventManager, run_agent_task

logger = logging.getLogger(__name__)
router = APIRouter()

# Currently running tasks
_running_tasks: Dict[str, AgentRunner] = {}


# ============ Schemas ============

class AgentTaskCreate(BaseModel):
    """Request body for creating an Agent task"""
    project_id: str = Field(..., description="项目 ID")
    name: Optional[str] = Field(None, description="任务名称")
    description: Optional[str] = Field(None, description="任务描述")

    # Audit configuration
    audit_scope: Optional[dict] = Field(None, description="审计范围")
    target_vulnerabilities: Optional[List[str]] = Field(
        default=["sql_injection", "xss", "command_injection", "path_traversal", "ssrf"],
        description="目标漏洞类型"
    )
    verification_level: str = Field(
        "sandbox",
        description="验证级别: analysis_only, sandbox, generate_poc"
    )

    # Branch
    branch_name: Optional[str] = Field(None, description="分支名称")

    # Exclude patterns
    exclude_patterns: Optional[List[str]] = Field(
        default=["node_modules", "__pycache__", ".git", "*.min.js"],
        description="排除模式"
    )

    # File scope
    target_files: Optional[List[str]] = Field(None, description="指定扫描的文件")

    # Agent configuration
    max_iterations: int = Field(3, ge=1, le=10, description="最大分析迭代次数")
    timeout_seconds: int = Field(1800, ge=60, le=7200, description="超时时间(秒)")


class AgentTaskResponse(BaseModel):
    """Agent task response"""
    id: str
    project_id: str
    name: Optional[str]
    description: Optional[str]
    status: str
    current_phase: Optional[str]

    # Statistics
    total_findings: int = 0
    verified_findings: int = 0
    security_score: Optional[int] = None

    # Timestamps
    created_at: datetime
    started_at: Optional[datetime] = None
    finished_at: Optional[datetime] = None

    # Configuration
    config: Optional[dict] = None

    # Error info
    error_message: Optional[str] = None

    class Config:
        from_attributes = True


class AgentEventResponse(BaseModel):
    """Agent event response"""
    id: str
    task_id: str
    event_type: str
    phase: Optional[str]
    message: str
    sequence: int
    created_at: datetime

    # Optional fields
    tool_name: Optional[str] = None
    tool_duration_ms: Optional[int] = None
    progress_percent: Optional[float] = None
    finding_id: Optional[str] = None

    class Config:
        from_attributes = True


class AgentFindingResponse(BaseModel):
    """Agent finding response"""
    id: str
    task_id: str
    vulnerability_type: str
    severity: str
    title: str
    description: Optional[str]
    file_path: Optional[str]
    line_start: Optional[int]
    line_end: Optional[int]
    code_snippet: Optional[str]

    is_verified: bool
    confidence: float
    status: str

    suggestion: Optional[str] = None
    poc: Optional[dict] = None

    created_at: datetime

    class Config:
        from_attributes = True


class TaskSummaryResponse(BaseModel):
    """Task summary response"""
    task_id: str
    status: str
    security_score: Optional[int]

    total_findings: int
    verified_findings: int

    severity_distribution: Dict[str, int]
    vulnerability_types: Dict[str, int]

    duration_seconds: Optional[int]
    phases_completed: List[str]


# ============ Background task execution ============

async def _execute_agent_task(task_id: str, project_root: str):
    """Run the Agent task in the background"""
    async with async_session_factory() as db:
        try:
            # Fetch the task
            task = await db.get(AgentTask, task_id, options=[selectinload(AgentTask.project)])
            if not task:
                logger.error(f"Task {task_id} not found")
                return

            # Create the runner
            runner = AgentRunner(db, task, project_root)
            _running_tasks[task_id] = runner

            # Run it
            result = await runner.run()

            logger.info(f"Task {task_id} completed: {result.get('success', False)}")

        except Exception as e:
            logger.error(f"Task {task_id} failed: {e}", exc_info=True)

            # Update the task status
            task = await db.get(AgentTask, task_id)
            if task:
                task.status = AgentTaskStatus.FAILED
                task.error_message = str(e)
                task.finished_at = datetime.now(timezone.utc)
                await db.commit()

        finally:
            # Cleanup
            _running_tasks.pop(task_id, None)


# ============ API Endpoints ============

@router.post("/", response_model=AgentTaskResponse)
async def create_agent_task(
    request: AgentTaskCreate,
    background_tasks: BackgroundTasks,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(deps.get_current_user),
) -> Any:
    """
    Create and start an Agent audit task
    """
    # Validate the project
    project = await db.get(Project, request.project_id)
    if not project:
        raise HTTPException(status_code=404, detail="项目不存在")

    if project.owner_id != current_user.id:
        raise HTTPException(status_code=403, detail="无权访问此项目")

    # Create the task
    task = AgentTask(
        id=str(uuid4()),
        project_id=project.id,
        name=request.name or f"Agent Audit - {datetime.now().strftime('%Y%m%d_%H%M%S')}",
        description=request.description,
        status=AgentTaskStatus.PENDING,
        current_phase=AgentTaskPhase.PLANNING,
        config={
            "target_vulnerabilities": request.target_vulnerabilities,
            "verification_level": request.verification_level,
            "exclude_patterns": request.exclude_patterns,
            "target_files": request.target_files,
            "max_iterations": request.max_iterations,
            "timeout_seconds": request.timeout_seconds,
        },
        created_by=current_user.id,
    )

    db.add(task)
    await db.commit()
    await db.refresh(task)

    # Determine the project root directory
    project_root = _get_project_root(project, task.id)

    # Start the task in the background
    background_tasks.add_task(_execute_agent_task, task.id, project_root)

    logger.info(f"Created agent task {task.id} for project {project.name}")

    return task


@router.get("/", response_model=List[AgentTaskResponse])
async def list_agent_tasks(
    project_id: Optional[str] = None,
    status: Optional[str] = None,
    skip: int = Query(0, ge=0),
    limit: int = Query(20, ge=1, le=100),
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(deps.get_current_user),
) -> Any:
    """
    List Agent tasks
    """
    # Get the user's projects
    projects_result = await db.execute(
        select(Project.id).where(Project.owner_id == current_user.id)
    )
    user_project_ids = [p[0] for p in projects_result.fetchall()]

    if not user_project_ids:
        return []

    # Build the query
    query = select(AgentTask).where(AgentTask.project_id.in_(user_project_ids))

    if project_id:
        query = query.where(AgentTask.project_id == project_id)

    if status:
        try:
            status_enum = AgentTaskStatus(status)
            query = query.where(AgentTask.status == status_enum)
        except ValueError:
            pass

    query = query.order_by(AgentTask.created_at.desc())
    query = query.offset(skip).limit(limit)

    result = await db.execute(query)
    tasks = result.scalars().all()

    return tasks


@router.get("/{task_id}", response_model=AgentTaskResponse)
async def get_agent_task(
    task_id: str,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(deps.get_current_user),
) -> Any:
    """
    Get Agent task details
    """
    task = await db.get(AgentTask, task_id)
    if not task:
        raise HTTPException(status_code=404, detail="任务不存在")

    # Check permissions
    project = await db.get(Project, task.project_id)
    if not project or project.owner_id != current_user.id:
        raise HTTPException(status_code=403, detail="无权访问此任务")

    return task


@router.post("/{task_id}/cancel")
async def cancel_agent_task(
    task_id: str,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(deps.get_current_user),
) -> Any:
    """
    Cancel an Agent task
    """
    task = await db.get(AgentTask, task_id)
    if not task:
        raise HTTPException(status_code=404, detail="任务不存在")

    project = await db.get(Project, task.project_id)
    if not project or project.owner_id != current_user.id:
        raise HTTPException(status_code=403, detail="无权操作此任务")

    if task.status in [AgentTaskStatus.COMPLETED, AgentTaskStatus.FAILED, AgentTaskStatus.CANCELLED]:
        raise HTTPException(status_code=400, detail="任务已结束,无法取消")

    # Cancel the running task
    runner = _running_tasks.get(task_id)
    if runner:
        runner.cancel()

    # Update the status
    task.status = AgentTaskStatus.CANCELLED
    task.finished_at = datetime.now(timezone.utc)
    await db.commit()

    return {"message": "任务已取消", "task_id": task_id}


@router.get("/{task_id}/events")
async def stream_agent_events(
    task_id: str,
    after_sequence: int = Query(0, ge=0, description="从哪个序号之后开始"),
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(deps.get_current_user),
):
    """
    Stream Agent events (SSE)
    """
    task = await db.get(AgentTask, task_id)
    if not task:
        raise HTTPException(status_code=404, detail="任务不存在")

    project = await db.get(Project, task.project_id)
    if not project or project.owner_id != current_user.id:
        raise HTTPException(status_code=403, detail="无权访问此任务")

    async def event_generator():
        """Generate the SSE event stream"""
        last_sequence = after_sequence
        poll_interval = 0.5
        max_idle = 300  # Close the stream after 5 minutes without events
        idle_time = 0

        while True:
            # Query new events
            async with async_session_factory() as session:
                result = await session.execute(
                    select(AgentEvent)
                    .where(AgentEvent.task_id == task_id)
                    .where(AgentEvent.sequence > last_sequence)
                    .order_by(AgentEvent.sequence)
                    .limit(50)
                )
                events = result.scalars().all()

                # Get the task status
                current_task = await session.get(AgentTask, task_id)
                task_status = current_task.status if current_task else None

            if events:
                idle_time = 0
                for event in events:
                    last_sequence = event.sequence
                    data = {
                        "id": event.id,
                        "type": event.event_type.value if hasattr(event.event_type, 'value') else str(event.event_type),
                        "phase": event.phase.value if event.phase and hasattr(event.phase, 'value') else event.phase,
                        "message": event.message,
                        "sequence": event.sequence,
                        "timestamp": event.created_at.isoformat() if event.created_at else None,
                        "progress_percent": event.progress_percent,
                        "tool_name": event.tool_name,
                    }
                    yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
            else:
                idle_time += poll_interval

            # Check whether the task has finished
            if task_status in [AgentTaskStatus.COMPLETED, AgentTaskStatus.FAILED, AgentTaskStatus.CANCELLED]:
                yield f"data: {json.dumps({'type': 'task_end', 'status': task_status.value})}\n\n"
                break

            # Check the idle timeout
            if idle_time >= max_idle:
                yield f"data: {json.dumps({'type': 'timeout'})}\n\n"
                break

            await asyncio.sleep(poll_interval)

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        }
    )


@router.get("/{task_id}/events/list", response_model=List[AgentEventResponse])
async def list_agent_events(
    task_id: str,
    after_sequence: int = Query(0, ge=0),
    limit: int = Query(100, ge=1, le=500),
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(deps.get_current_user),
) -> Any:
    """
    List Agent events
    """
    task = await db.get(AgentTask, task_id)
    if not task:
        raise HTTPException(status_code=404, detail="任务不存在")

    project = await db.get(Project, task.project_id)
    if not project or project.owner_id != current_user.id:
        raise HTTPException(status_code=403, detail="无权访问此任务")

    result = await db.execute(
        select(AgentEvent)
        .where(AgentEvent.task_id == task_id)
        .where(AgentEvent.sequence > after_sequence)
        .order_by(AgentEvent.sequence)
        .limit(limit)
    )
    events = result.scalars().all()

    return events


@router.get("/{task_id}/findings", response_model=List[AgentFindingResponse])
async def list_agent_findings(
    task_id: str,
    severity: Optional[str] = None,
    verified_only: bool = False,
    skip: int = Query(0, ge=0),
    limit: int = Query(50, ge=1, le=200),
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(deps.get_current_user),
) -> Any:
    """
    List Agent findings
    """
    task = await db.get(AgentTask, task_id)
    if not task:
        raise HTTPException(status_code=404, detail="任务不存在")

    project = await db.get(Project, task.project_id)
    if not project or project.owner_id != current_user.id:
        raise HTTPException(status_code=403, detail="无权访问此任务")

    query = select(AgentFinding).where(AgentFinding.task_id == task_id)

    if severity:
        try:
            sev_enum = VulnerabilitySeverity(severity)
            query = query.where(AgentFinding.severity == sev_enum)
        except ValueError:
            pass

    if verified_only:
        query = query.where(AgentFinding.is_verified == True)

    # Order by severity
    severity_order = {
        VulnerabilitySeverity.CRITICAL: 0,
        VulnerabilitySeverity.HIGH: 1,
        VulnerabilitySeverity.MEDIUM: 2,
        VulnerabilitySeverity.LOW: 3,
        VulnerabilitySeverity.INFO: 4,
    }

    query = query.order_by(AgentFinding.severity, AgentFinding.created_at.desc())
    query = query.offset(skip).limit(limit)

    result = await db.execute(query)
    findings = result.scalars().all()

    return findings


@router.get("/{task_id}/summary", response_model=TaskSummaryResponse)
async def get_task_summary(
    task_id: str,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(deps.get_current_user),
) -> Any:
    """
    Get the task summary
    """
    task = await db.get(AgentTask, task_id)
    if not task:
        raise HTTPException(status_code=404, detail="任务不存在")

    project = await db.get(Project, task.project_id)
    if not project or project.owner_id != current_user.id:
        raise HTTPException(status_code=403, detail="无权访问此任务")

    # Fetch all findings
    result = await db.execute(
        select(AgentFinding).where(AgentFinding.task_id == task_id)
    )
    findings = result.scalars().all()

    # Statistics
    severity_distribution = {}
    vulnerability_types = {}
    verified_count = 0

    for f in findings:
        sev = f.severity.value if hasattr(f.severity, 'value') else str(f.severity)
        vtype = f.vulnerability_type.value if hasattr(f.vulnerability_type, 'value') else str(f.vulnerability_type)

        severity_distribution[sev] = severity_distribution.get(sev, 0) + 1
        vulnerability_types[vtype] = vulnerability_types.get(vtype, 0) + 1

        if f.is_verified:
            verified_count += 1

    # Compute the duration
    duration = None
    if task.started_at and task.finished_at:
        duration = int((task.finished_at - task.started_at).total_seconds())

    # Get the completed phases
    phases_result = await db.execute(
        select(AgentEvent.phase)
        .where(AgentEvent.task_id == task_id)
        .where(AgentEvent.event_type == AgentEventType.PHASE_COMPLETE)
        .distinct()
    )
    phases = [p[0].value if p[0] and hasattr(p[0], 'value') else str(p[0]) for p in phases_result.fetchall() if p[0]]

    return TaskSummaryResponse(
        task_id=task_id,
        status=task.status.value if hasattr(task.status, 'value') else str(task.status),
        security_score=task.security_score,
        total_findings=len(findings),
        verified_findings=verified_count,
        severity_distribution=severity_distribution,
        vulnerability_types=vulnerability_types,
        duration_seconds=duration,
        phases_completed=phases,
    )


@router.patch("/{task_id}/findings/{finding_id}/status")
async def update_finding_status(
    task_id: str,
    finding_id: str,
    status: str,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(deps.get_current_user),
) -> Any:
    """
    Update a finding's status
    """
    task = await db.get(AgentTask, task_id)
    if not task:
        raise HTTPException(status_code=404, detail="任务不存在")

    project = await db.get(Project, task.project_id)
    if not project or project.owner_id != current_user.id:
        raise HTTPException(status_code=403, detail="无权操作")

    finding = await db.get(AgentFinding, finding_id)
    if not finding or finding.task_id != task_id:
        raise HTTPException(status_code=404, detail="发现不存在")

    try:
        finding.status = FindingStatus(status)
    except ValueError:
        raise HTTPException(status_code=400, detail=f"无效的状态: {status}")

    await db.commit()

    return {"message": "状态已更新", "finding_id": finding_id, "status": status}


# ============ Helper Functions ============

def _get_project_root(project: Project, task_id: str) -> str:
    """
    Get the project root directory

    TODO: a real implementation needs to:
    - For ZIP projects: extract to a temporary directory
    - For Git repositories: clone to a temporary directory
    """
    base_path = f"/tmp/deepaudit/{task_id}"

    # Make sure the directory exists
    os.makedirs(base_path, exist_ok=True)

    # If the project has a storage path, copy it over
    if hasattr(project, 'storage_path') and project.storage_path:
        if os.path.exists(project.storage_path):
            # Copy the project files
            shutil.copytree(project.storage_path, base_path, dirs_exist_ok=True)

    return base_path
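The GET /{task_id}/events endpoint above emits a polling-backed SSE stream: one "data:" line per event, terminated by a task_end or timeout message. A minimal consumer sketch (the base URL and token are assumptions, not part of this diff):

import json
import requests

url = "http://localhost:8000/api/v1/agent-tasks/<task-id>/events"  # assumed mount point
with requests.get(url, headers={"Authorization": "Bearer <token>"}, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue  # skip blank keep-alive lines
        event = json.loads(line[len("data: "):])
        print(event.get("sequence"), event.get("type"), event.get("message"))
        if event.get("type") in ("task_end", "timeout"):
            break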
@@ -0,0 +1,396 @@
"""
Embedding model configuration API
Independent of the LLM configuration; dedicated to the embedding models used by the RAG system
Persisted in UserConfig.other_config
"""

import json
import uuid
from typing import Any, Optional, List
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.api import deps
from app.models.user import User
from app.models.user_config import UserConfig
from app.core.config import settings

router = APIRouter()


# ============ Schemas ============

class EmbeddingProvider(BaseModel):
    """Embedding model provider"""
    id: str
    name: str
    description: str
    models: List[str]
    requires_api_key: bool
    default_model: str


class EmbeddingConfig(BaseModel):
    """Embedding model configuration"""
    provider: str = Field(description="提供商: openai, ollama, azure, cohere, huggingface")
    model: str = Field(description="模型名称")
    api_key: Optional[str] = Field(default=None, description="API Key (如需要)")
    base_url: Optional[str] = Field(default=None, description="自定义 API 端点")
    dimensions: Optional[int] = Field(default=None, description="向量维度 (某些模型支持)")
    batch_size: int = Field(default=100, description="批处理大小")


class EmbeddingConfigResponse(BaseModel):
    """Configuration response"""
    provider: str
    model: str
    base_url: Optional[str]
    dimensions: int
    batch_size: int
    # The API key is never returned


class TestEmbeddingRequest(BaseModel):
    """Embedding test request"""
    provider: str
    model: str
    api_key: Optional[str] = None
    base_url: Optional[str] = None
    test_text: str = "这是一段测试文本,用于验证嵌入模型是否正常工作。"


class TestEmbeddingResponse(BaseModel):
    """Embedding test response"""
    success: bool
    message: str
    dimensions: Optional[int] = None
    sample_embedding: Optional[List[float]] = None  # First 5 dimensions
    latency_ms: Optional[int] = None


# ============ Provider catalogue ============

EMBEDDING_PROVIDERS: List[EmbeddingProvider] = [
    EmbeddingProvider(
        id="openai",
        name="OpenAI",
        description="OpenAI 官方嵌入模型,高质量、稳定",
        models=[
            "text-embedding-3-small",
            "text-embedding-3-large",
            "text-embedding-ada-002",
        ],
        requires_api_key=True,
        default_model="text-embedding-3-small",
    ),
    EmbeddingProvider(
        id="azure",
        name="Azure OpenAI",
        description="Azure 托管的 OpenAI 嵌入模型",
        models=[
            "text-embedding-3-small",
            "text-embedding-3-large",
            "text-embedding-ada-002",
        ],
        requires_api_key=True,
        default_model="text-embedding-3-small",
    ),
    EmbeddingProvider(
        id="ollama",
        name="Ollama (本地)",
        description="本地运行的开源嵌入模型 (使用 /api/embed 端点)",
        models=[
            "nomic-embed-text",
            "mxbai-embed-large",
            "all-minilm",
            "snowflake-arctic-embed",
            "bge-m3",
            "qwen3-embedding",
        ],
        requires_api_key=False,
        default_model="nomic-embed-text",
    ),
    EmbeddingProvider(
        id="cohere",
        name="Cohere",
        description="Cohere Embed v2 API (api.cohere.com/v2)",
        models=[
            "embed-english-v3.0",
            "embed-multilingual-v3.0",
            "embed-english-light-v3.0",
            "embed-multilingual-light-v3.0",
            "embed-v4.0",
        ],
        requires_api_key=True,
        default_model="embed-multilingual-v3.0",
    ),
    EmbeddingProvider(
        id="huggingface",
        name="HuggingFace",
        description="HuggingFace Inference Providers (router.huggingface.co)",
        models=[
            "sentence-transformers/all-MiniLM-L6-v2",
            "sentence-transformers/all-mpnet-base-v2",
            "BAAI/bge-large-zh-v1.5",
            "BAAI/bge-m3",
        ],
        requires_api_key=True,
        default_model="BAAI/bge-m3",
    ),
    EmbeddingProvider(
        id="jina",
        name="Jina AI",
        description="Jina AI 嵌入模型,代码嵌入效果好",
        models=[
            "jina-embeddings-v2-base-code",
            "jina-embeddings-v2-base-en",
            "jina-embeddings-v2-base-zh",
        ],
        requires_api_key=True,
        default_model="jina-embeddings-v2-base-code",
    ),
]


# ============ Database persistence (async) ============

EMBEDDING_CONFIG_KEY = "embedding_config"


async def get_embedding_config_from_db(db: AsyncSession, user_id: str) -> EmbeddingConfig:
    """Load the embedding configuration from the database (async)"""
    result = await db.execute(
        select(UserConfig).where(UserConfig.user_id == user_id)
    )
    user_config = result.scalar_one_or_none()

    if user_config and user_config.other_config:
        try:
            other_config = json.loads(user_config.other_config) if isinstance(user_config.other_config, str) else user_config.other_config
            embedding_data = other_config.get(EMBEDDING_CONFIG_KEY)

            if embedding_data:
                return EmbeddingConfig(
                    provider=embedding_data.get("provider", settings.EMBEDDING_PROVIDER),
                    model=embedding_data.get("model", settings.EMBEDDING_MODEL),
                    api_key=embedding_data.get("api_key"),
                    base_url=embedding_data.get("base_url"),
                    dimensions=embedding_data.get("dimensions"),
                    batch_size=embedding_data.get("batch_size", 100),
                )
        except (json.JSONDecodeError, AttributeError):
            pass

    # Fall back to the default configuration
    return EmbeddingConfig(
        provider=settings.EMBEDDING_PROVIDER,
        model=settings.EMBEDDING_MODEL,
        api_key=settings.LLM_API_KEY,
        base_url=settings.LLM_BASE_URL,
        batch_size=100,
    )


async def save_embedding_config_to_db(db: AsyncSession, user_id: str, config: EmbeddingConfig) -> None:
    """Save the embedding configuration to the database (async)"""
    result = await db.execute(
        select(UserConfig).where(UserConfig.user_id == user_id)
    )
    user_config = result.scalar_one_or_none()

    # Prepare the embedding config payload
    embedding_data = {
        "provider": config.provider,
        "model": config.model,
        "api_key": config.api_key,
        "base_url": config.base_url,
        "dimensions": config.dimensions,
        "batch_size": config.batch_size,
    }

    if user_config:
        # Update the existing config
        try:
            other_config = json.loads(user_config.other_config) if user_config.other_config else {}
        except (json.JSONDecodeError, TypeError):
            other_config = {}

        other_config[EMBEDDING_CONFIG_KEY] = embedding_data
        user_config.other_config = json.dumps(other_config)
    else:
        # Create a new config row
        user_config = UserConfig(
            id=str(uuid.uuid4()),
            user_id=user_id,
            llm_config="{}",
            other_config=json.dumps({EMBEDDING_CONFIG_KEY: embedding_data}),
        )
        db.add(user_config)

    await db.commit()


# ============ API Endpoints ============

@router.get("/providers", response_model=List[EmbeddingProvider])
async def list_embedding_providers(
    current_user: User = Depends(deps.get_current_user),
) -> Any:
    """
    List the available embedding model providers
    """
    return EMBEDDING_PROVIDERS


@router.get("/config", response_model=EmbeddingConfigResponse)
async def get_current_config(
    db: AsyncSession = Depends(deps.get_db),
    current_user: User = Depends(deps.get_current_user),
) -> Any:
    """
    Get the current embedding configuration (read from the database)
    """
    config = await get_embedding_config_from_db(db, current_user.id)

    # Resolve the dimensions
    dimensions = _get_model_dimensions(config.provider, config.model)

    return EmbeddingConfigResponse(
        provider=config.provider,
        model=config.model,
        base_url=config.base_url,
        dimensions=dimensions,
        batch_size=config.batch_size,
    )


@router.put("/config")
async def update_config(
    config: EmbeddingConfig,
    db: AsyncSession = Depends(deps.get_db),
    current_user: User = Depends(deps.get_current_user),
) -> Any:
    """
    Update the embedding configuration (persisted to the database)
    """
    # Validate the provider
    provider_ids = [p.id for p in EMBEDDING_PROVIDERS]
    if config.provider not in provider_ids:
        raise HTTPException(status_code=400, detail=f"不支持的提供商: {config.provider}")

    # Validate the model
    provider = next((p for p in EMBEDDING_PROVIDERS if p.id == config.provider), None)
    if provider and config.model not in provider.models:
        raise HTTPException(status_code=400, detail=f"不支持的模型: {config.model}")

    # Check the API key
    if provider and provider.requires_api_key and not config.api_key:
        raise HTTPException(status_code=400, detail=f"{config.provider} 需要 API Key")

    # Save to the database
    await save_embedding_config_to_db(db, current_user.id, config)

    return {"message": "配置已保存", "provider": config.provider, "model": config.model}


@router.post("/test", response_model=TestEmbeddingResponse)
async def test_embedding(
    request: TestEmbeddingRequest,
    current_user: User = Depends(deps.get_current_user),
) -> Any:
    """
    Test an embedding model configuration
    """
    import time

    try:
        start_time = time.time()

        # Create a temporary embedding service
        from app.services.rag.embeddings import EmbeddingService

        service = EmbeddingService(
            provider=request.provider,
            model=request.model,
            api_key=request.api_key,
            base_url=request.base_url,
            cache_enabled=False,
        )

        # Run the embedding
        embedding = await service.embed(request.test_text)

        latency_ms = int((time.time() - start_time) * 1000)

        return TestEmbeddingResponse(
            success=True,
            message=f"嵌入成功! 维度: {len(embedding)}",
            dimensions=len(embedding),
            sample_embedding=embedding[:5],  # Return the first 5 dimensions
            latency_ms=latency_ms,
        )

    except Exception as e:
        return TestEmbeddingResponse(
            success=False,
            message=f"嵌入失败: {str(e)}",
        )


@router.get("/models/{provider}")
async def get_provider_models(
    provider: str,
    current_user: User = Depends(deps.get_current_user),
) -> Any:
    """
    List the models for a given provider
    """
    provider_info = next((p for p in EMBEDDING_PROVIDERS if p.id == provider), None)

    if not provider_info:
        raise HTTPException(status_code=404, detail=f"提供商不存在: {provider}")

    return {
        "provider": provider,
        "models": provider_info.models,
        "default_model": provider_info.default_model,
        "requires_api_key": provider_info.requires_api_key,
    }


def _get_model_dimensions(provider: str, model: str) -> int:
    """Look up the output dimensions for a model"""
    dimensions_map = {
        # OpenAI
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
        "text-embedding-ada-002": 1536,

        # Ollama
        "nomic-embed-text": 768,
        "mxbai-embed-large": 1024,
        "all-minilm": 384,
        "snowflake-arctic-embed": 1024,

        # Cohere
        "embed-english-v3.0": 1024,
        "embed-multilingual-v3.0": 1024,
        "embed-english-light-v3.0": 384,
        "embed-multilingual-light-v3.0": 384,

        # HuggingFace
        "sentence-transformers/all-MiniLM-L6-v2": 384,
        "sentence-transformers/all-mpnet-base-v2": 768,
        "BAAI/bge-large-zh-v1.5": 1024,
        "BAAI/bge-m3": 1024,

        # Jina
        "jina-embeddings-v2-base-code": 768,
        "jina-embeddings-v2-base-en": 768,
        "jina-embeddings-v2-base-zh": 768,
    }

    return dimensions_map.get(model, 768)
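Typical flow for the embedding endpoints above: list providers, dry-run a configuration with POST /test, then persist it with PUT /config. A sketch only; the mount point and token are assumptions:

import requests

BASE = "http://localhost:8000/api/v1/embedding"  # assumed mount point
HEADERS = {"Authorization": "Bearer <token>"}    # assumed auth scheme

cfg = {"provider": "ollama", "model": "nomic-embed-text", "base_url": "http://localhost:11434"}

# Dry-run first; the response reports the vector dimensions and latency.
test = requests.post(f"{BASE}/test", json={**cfg, "test_text": "hello"}, headers=HEADERS).json()
if test["success"]:
    # Persist the configuration for the current user (stored in UserConfig.other_config).
    requests.put(f"{BASE}/config", json=cfg, headers=HEADERS)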
@@ -209,3 +209,4 @@ async def remove_project_member(
    return {"message": "成员已移除"}
@@ -224,3 +224,4 @@ async def toggle_user_status(
    return user
@@ -76,6 +76,32 @@ class Settings(BaseSettings):
 
     # Output language configuration - supports zh-CN (Chinese) and en-US (English)
     OUTPUT_LANGUAGE: str = "zh-CN"
 
+    # ============ Agent module configuration ============
+
+    # Embedding model configuration
+    EMBEDDING_PROVIDER: str = "openai"  # openai, ollama, litellm
+    EMBEDDING_MODEL: str = "text-embedding-3-small"
+
+    # Vector database configuration
+    VECTOR_DB_PATH: str = "./data/vector_db"  # Vector store persistence directory
+
+    # Agent configuration
+    AGENT_MAX_ITERATIONS: int = 50  # Maximum Agent iterations
+    AGENT_TOKEN_BUDGET: int = 100000  # Agent token budget
+    AGENT_TIMEOUT_SECONDS: int = 1800  # Agent timeout (30 minutes)
+
+    # Sandbox configuration
+    SANDBOX_IMAGE: str = "deepaudit-sandbox:latest"  # Sandbox Docker image
+    SANDBOX_MEMORY_LIMIT: str = "512m"  # Sandbox memory limit
+    SANDBOX_CPU_LIMIT: float = 1.0  # Sandbox CPU limit
+    SANDBOX_TIMEOUT: int = 60  # Sandbox command timeout (seconds)
+    SANDBOX_NETWORK_MODE: str = "none"  # Sandbox network mode (none, bridge)
+
+    # RAG configuration
+    RAG_CHUNK_SIZE: int = 1500  # Code chunk size (tokens)
+    RAG_CHUNK_OVERLAP: int = 50  # Code chunk overlap (tokens)
+    RAG_TOP_K: int = 10  # Number of retrieved results
+
     class Config:
         case_sensitive = True
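To get a feel for the RAG defaults above: a 1500-token chunk with a 50-token overlap means every additional chunk advances 1450 tokens. A quick estimate, assuming a plain sliding-window chunker (the actual chunking strategy is not part of this diff):

chunk_size, overlap = 1500, 50        # RAG_CHUNK_SIZE, RAG_CHUNK_OVERLAP
stride = chunk_size - overlap         # 1450 tokens per additional chunk
file_tokens = 10_000

extra = max(0, -(-(file_tokens - chunk_size) // stride))  # ceiling division
print(1 + extra)  # -> 7 chunks for a 10,000-token file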
@@ -28,3 +28,4 @@ def get_password_hash(password: str) -> str:
    return pwd_context.hash(password)
@@ -11,3 +11,4 @@ class Base:
        return cls.__name__.lower() + "s"
@@ -1,3 +1,4 @@
+from contextlib import asynccontextmanager
 from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
 from sqlalchemy.orm import sessionmaker
 from app.core.config import settings
@@ -16,3 +17,14 @@ async def get_db():
         await session.close()
+
+
+@asynccontextmanager
+async def async_session_factory():
+    """Async context manager for creating database sessions"""
+    async with AsyncSessionLocal() as session:
+        try:
+            yield session
+        finally:
+            await session.close()
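get_db() remains the per-request FastAPI dependency; async_session_factory() opens the same kind of session outside a request, which is what the background Agent runner uses. A small usage sketch (AgentTask is the model added elsewhere in this commit):

from sqlalchemy import func, select

from app.db.session import async_session_factory
from app.models.agent_task import AgentTask

async def count_agent_tasks() -> int:
    # Open a session outside the request/response cycle, e.g. from a background task.
    async with async_session_factory() as session:
        result = await session.execute(select(func.count()).select_from(AgentTask))
        return result.scalar_one()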
@@ -1,8 +1,15 @@
 from .user import User
+from .user_config import UserConfig
 from .project import Project, ProjectMember
 from .audit import AuditTask, AuditIssue
 from .analysis import InstantAnalysis
 from .prompt_template import PromptTemplate
 from .audit_rule import AuditRuleSet, AuditRule
+from .agent_task import (
+    AgentTask, AgentEvent, AgentFinding,
+    AgentTaskStatus, AgentTaskPhase, AgentEventType,
+    VulnerabilitySeverity, VulnerabilityType, FindingStatus
+)
@@ -0,0 +1,443 @@
"""
Agent audit task models
Support autonomous AI Agent vulnerability discovery and verification
"""

import uuid
from datetime import datetime
from typing import Optional, List, TYPE_CHECKING
from sqlalchemy import (
    Column, String, Integer, Float, Text, Boolean,
    DateTime, ForeignKey, Enum as SQLEnum, JSON
)
from sqlalchemy.orm import relationship
from sqlalchemy.sql import func

from app.db.base import Base

if TYPE_CHECKING:
    from .project import Project


class AgentTaskStatus:
    """Agent task status"""
    PENDING = "pending"              # Waiting to run
    INITIALIZING = "initializing"    # Initializing
    PLANNING = "planning"            # Planning phase
    INDEXING = "indexing"            # Indexing phase
    ANALYZING = "analyzing"          # Analysis phase
    VERIFYING = "verifying"          # Verification phase
    REPORTING = "reporting"          # Report generation
    COMPLETED = "completed"          # Finished
    FAILED = "failed"                # Failed
    CANCELLED = "cancelled"          # Cancelled
    PAUSED = "paused"                # Paused


class AgentTaskPhase:
    """Agent execution phases"""
    PLANNING = "planning"
    INDEXING = "indexing"
    RECONNAISSANCE = "reconnaissance"
    ANALYSIS = "analysis"
    VERIFICATION = "verification"
    REPORTING = "reporting"


class AgentTask(Base):
    """Agent audit task"""
    __tablename__ = "agent_tasks"

    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    project_id = Column(String(36), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False)

    # Basic task info
    name = Column(String(255), nullable=True)
    description = Column(Text, nullable=True)
    task_type = Column(String(50), default="agent_audit")

    # Task configuration
    audit_scope = Column(JSON, nullable=True)              # Audit scope configuration
    target_vulnerabilities = Column(JSON, nullable=True)   # Target vulnerability types
    verification_level = Column(String(50), default="sandbox")  # analysis_only, sandbox, generate_poc

    # Branch info (repository projects)
    branch_name = Column(String(255), nullable=True)

    # Exclude patterns
    exclude_patterns = Column(JSON, nullable=True)

    # File scope
    target_files = Column(JSON, nullable=True)  # Explicit list of files to scan

    # LLM configuration
    llm_config = Column(JSON, nullable=True)

    # Agent configuration
    agent_config = Column(JSON, nullable=True)        # Agent-specific configuration
    max_iterations = Column(Integer, default=50)       # Maximum iterations
    token_budget = Column(Integer, default=100000)     # Token budget
    timeout_seconds = Column(Integer, default=1800)    # Timeout in seconds

    # Status
    status = Column(String(20), default=AgentTaskStatus.PENDING)
    current_phase = Column(String(50), nullable=True)
    current_step = Column(String(255), nullable=True)  # Description of the current step
    error_message = Column(Text, nullable=True)

    # Progress statistics
    total_files = Column(Integer, default=0)
    indexed_files = Column(Integer, default=0)
    analyzed_files = Column(Integer, default=0)
    total_chunks = Column(Integer, default=0)  # Total number of code chunks

    # Agent statistics
    total_iterations = Column(Integer, default=0)  # Agent iterations
    tool_calls_count = Column(Integer, default=0)  # Number of tool calls
    tokens_used = Column(Integer, default=0)       # Tokens consumed

    # Finding statistics
    findings_count = Column(Integer, default=0)        # Total findings
    verified_count = Column(Integer, default=0)        # Verified findings
    false_positive_count = Column(Integer, default=0)  # False positives

    # Severity statistics
    critical_count = Column(Integer, default=0)
    high_count = Column(Integer, default=0)
    medium_count = Column(Integer, default=0)
    low_count = Column(Integer, default=0)

    # Quality scores
    quality_score = Column(Float, default=0.0)
    security_score = Column(Float, default=0.0)

    # Audit plan
    audit_plan = Column(JSON, nullable=True)  # Audit plan generated by the Agent

    # Timestamps
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(DateTime(timezone=True), onupdate=func.now())
    started_at = Column(DateTime(timezone=True), nullable=True)
    completed_at = Column(DateTime(timezone=True), nullable=True)

    # Creator
    created_by = Column(String(36), ForeignKey("users.id"), nullable=False)

    # Relationships
    project = relationship("Project", back_populates="agent_tasks")
    events = relationship("AgentEvent", back_populates="task", cascade="all, delete-orphan", order_by="AgentEvent.created_at")
    findings = relationship("AgentFinding", back_populates="task", cascade="all, delete-orphan")

    def __repr__(self):
        return f"<AgentTask {self.id} - {self.status}>"

    @property
    def progress_percentage(self) -> float:
        """Estimate overall progress as a percentage."""
        if self.status == AgentTaskStatus.COMPLETED:
            return 100.0
        if self.status in [AgentTaskStatus.FAILED, AgentTaskStatus.CANCELLED]:
            return 0.0

        phase_weights = {
            AgentTaskPhase.PLANNING: 5,
            AgentTaskPhase.INDEXING: 15,
            AgentTaskPhase.RECONNAISSANCE: 10,
            AgentTaskPhase.ANALYSIS: 50,
            AgentTaskPhase.VERIFICATION: 15,
            AgentTaskPhase.REPORTING: 5,
        }

        completed_weight = 0
        current_found = False

        for phase, weight in phase_weights.items():
            if phase == self.current_phase:
                current_found = True
                # Estimate progress within the current phase
                if phase == AgentTaskPhase.INDEXING and self.total_files > 0:
                    completed_weight += weight * (self.indexed_files / self.total_files)
                elif phase == AgentTaskPhase.ANALYSIS and self.total_files > 0:
                    completed_weight += weight * (self.analyzed_files / self.total_files)
                else:
                    completed_weight += weight * 0.5
                break
            elif not current_found:
                completed_weight += weight

        return min(completed_weight, 99.0)

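To make the weighting above concrete, a hypothetical mid-analysis task (numbers invented for illustration): all phases before the current one count fully, and the analysis phase contributes its weight scaled by analyzed_files / total_files.

# Hypothetical example: current_phase == "analysis", analyzed_files = 10, total_files = 20
# progress = 5 + 15 + 10 + 50 * (10 / 20) = 55.0 (capped at 99.0 until the task completes)
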
class AgentEventType:
    """Agent event types"""
    # System events
    TASK_START = "task_start"
    TASK_COMPLETE = "task_complete"
    TASK_ERROR = "task_error"
    TASK_CANCEL = "task_cancel"

    # Phase events
    PHASE_START = "phase_start"
    PHASE_COMPLETE = "phase_complete"

    # Agent thinking
    THINKING = "thinking"
    PLANNING = "planning"
    DECISION = "decision"

    # Tool calls
    TOOL_CALL = "tool_call"
    TOOL_RESULT = "tool_result"
    TOOL_ERROR = "tool_error"

    # RAG
    RAG_QUERY = "rag_query"
    RAG_RESULT = "rag_result"

    # Findings
    FINDING_NEW = "finding_new"
    FINDING_UPDATE = "finding_update"
    FINDING_VERIFIED = "finding_verified"
    FINDING_FALSE_POSITIVE = "finding_false_positive"

    # Sandbox
    SANDBOX_START = "sandbox_start"
    SANDBOX_EXEC = "sandbox_exec"
    SANDBOX_RESULT = "sandbox_result"
    SANDBOX_ERROR = "sandbox_error"

    # Progress
    PROGRESS = "progress"

    # Logging
    INFO = "info"
    WARNING = "warning"
    ERROR = "error"
    DEBUG = "debug"


class AgentEvent(Base):
    """Agent execution event (for live logging and replay)"""
    __tablename__ = "agent_events"

    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    task_id = Column(String(36), ForeignKey("agent_tasks.id", ondelete="CASCADE"), nullable=False, index=True)

    # Event info
    event_type = Column(String(50), nullable=False, index=True)
    phase = Column(String(50), nullable=True)

    # Event content
    message = Column(Text, nullable=True)

    # Tool-call details
    tool_name = Column(String(100), nullable=True)
    tool_input = Column(JSON, nullable=True)
    tool_output = Column(JSON, nullable=True)
    tool_duration_ms = Column(Integer, nullable=True)  # Tool execution time in milliseconds

    # Associated finding
    finding_id = Column(String(36), nullable=True)

    # Token consumption
    tokens_used = Column(Integer, default=0)

    # Metadata
    event_metadata = Column(JSON, nullable=True)

    # Sequence number (for ordering)
    sequence = Column(Integer, default=0, index=True)

    # Timestamp
    created_at = Column(DateTime(timezone=True), server_default=func.now(), index=True)

    # Relationships
    task = relationship("AgentTask", back_populates="events")

    def __repr__(self):
        return f"<AgentEvent {self.event_type} - {self.message[:50] if self.message else ''}>"

    def to_sse_dict(self) -> dict:
        """Convert to the SSE event format."""
        return {
            "id": self.id,
            "type": self.event_type,
            "phase": self.phase,
            "message": self.message,
            "tool_name": self.tool_name,
            "tool_input": self.tool_input,
            "tool_output": self.tool_output,
            "tool_duration_ms": self.tool_duration_ms,
            "finding_id": self.finding_id,
            "tokens_used": self.tokens_used,
            "metadata": self.event_metadata,
            "sequence": self.sequence,
            "timestamp": self.created_at.isoformat() if self.created_at else None,
        }

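A minimal sketch of how to_sse_dict() could be framed on the wire; the event:/data: framing below is the generic Server-Sent Events format, and any endpoint wiring around it is an assumption, not something shown in this diff.

import json

def format_sse(event: "AgentEvent") -> str:
    # Serialize the event payload and frame it as one SSE message.
    payload = json.dumps(event.to_sse_dict(), ensure_ascii=False, default=str)
    return f"event: {event.event_type}\ndata: {payload}\n\n"
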
class VulnerabilitySeverity:
    """Vulnerability severity"""
    CRITICAL = "critical"
    HIGH = "high"
    MEDIUM = "medium"
    LOW = "low"
    INFO = "info"


class VulnerabilityType:
    """Vulnerability types"""
    SQL_INJECTION = "sql_injection"
    NOSQL_INJECTION = "nosql_injection"
    XSS = "xss"
    COMMAND_INJECTION = "command_injection"
    CODE_INJECTION = "code_injection"
    PATH_TRAVERSAL = "path_traversal"
    FILE_INCLUSION = "file_inclusion"
    SSRF = "ssrf"
    XXE = "xxe"
    DESERIALIZATION = "deserialization"
    AUTH_BYPASS = "auth_bypass"
    IDOR = "idor"
    SENSITIVE_DATA_EXPOSURE = "sensitive_data_exposure"
    HARDCODED_SECRET = "hardcoded_secret"
    WEAK_CRYPTO = "weak_crypto"
    RACE_CONDITION = "race_condition"
    BUSINESS_LOGIC = "business_logic"
    MEMORY_CORRUPTION = "memory_corruption"
    OTHER = "other"


class FindingStatus:
    """Finding status"""
    NEW = "new"                        # Newly discovered
    ANALYZING = "analyzing"            # Under analysis
    VERIFIED = "verified"              # Verified
    FALSE_POSITIVE = "false_positive"  # False positive
    NEEDS_REVIEW = "needs_review"      # Needs manual review
    FIXED = "fixed"                    # Fixed
    WONT_FIX = "wont_fix"              # Won't fix
    DUPLICATE = "duplicate"            # Duplicate


class AgentFinding(Base):
    """Vulnerability discovered by the Agent"""
    __tablename__ = "agent_findings"

    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    task_id = Column(String(36), ForeignKey("agent_tasks.id", ondelete="CASCADE"), nullable=False, index=True)

    # Basic vulnerability info
    vulnerability_type = Column(String(100), nullable=False, index=True)
    severity = Column(String(20), nullable=False, index=True)
    title = Column(String(500), nullable=False)
    description = Column(Text, nullable=True)

    # Location info
    file_path = Column(String(500), nullable=True, index=True)
    line_start = Column(Integer, nullable=True)
    line_end = Column(Integer, nullable=True)
    column_start = Column(Integer, nullable=True)
    column_end = Column(Integer, nullable=True)
    function_name = Column(String(255), nullable=True)
    class_name = Column(String(255), nullable=True)

    # Code snippets
    code_snippet = Column(Text, nullable=True)
    code_context = Column(Text, nullable=True)  # Additional context

    # Data-flow info
    source = Column(Text, nullable=True)          # Taint source
    sink = Column(Text, nullable=True)            # Dangerous sink function
    dataflow_path = Column(JSON, nullable=True)   # Data-flow path

    # Verification info
    status = Column(String(30), default=FindingStatus.NEW, index=True)
    is_verified = Column(Boolean, default=False)
    verification_method = Column(Text, nullable=True)
    verification_result = Column(JSON, nullable=True)
    verified_at = Column(DateTime(timezone=True), nullable=True)

    # PoC
    has_poc = Column(Boolean, default=False)
    poc_code = Column(Text, nullable=True)
    poc_description = Column(Text, nullable=True)
    poc_steps = Column(JSON, nullable=True)  # Reproduction steps

    # Fix recommendation
    suggestion = Column(Text, nullable=True)
    fix_code = Column(Text, nullable=True)
    fix_description = Column(Text, nullable=True)
    references = Column(JSON, nullable=True)  # Reference links: CWE, OWASP, etc.

    # AI explanation
    ai_explanation = Column(Text, nullable=True)
    ai_confidence = Column(Float, nullable=True)  # AI confidence, 0-1

    # XAI (explainable AI)
    xai_what = Column(Text, nullable=True)
    xai_why = Column(Text, nullable=True)
    xai_how = Column(Text, nullable=True)
    xai_impact = Column(Text, nullable=True)

    # Matched rule
    matched_rule_code = Column(String(100), nullable=True)
    matched_pattern = Column(Text, nullable=True)

    # CVSS score (optional)
    cvss_score = Column(Float, nullable=True)
    cvss_vector = Column(String(100), nullable=True)

    # Metadata
    finding_metadata = Column(JSON, nullable=True)
    tags = Column(JSON, nullable=True)

    # De-duplication key
    fingerprint = Column(String(64), nullable=True, index=True)  # Fingerprint used for de-duplication

    # Timestamps
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(DateTime(timezone=True), onupdate=func.now())

    # Relationships
    task = relationship("AgentTask", back_populates="findings")

    def __repr__(self):
        return f"<AgentFinding {self.vulnerability_type} - {self.severity} - {self.file_path}>"

    def generate_fingerprint(self) -> str:
        """Generate the de-duplication fingerprint."""
        import hashlib
        components = [
            self.vulnerability_type or "",
            self.file_path or "",
            str(self.line_start or 0),
            self.function_name or "",
            (self.code_snippet or "")[:200],
        ]
        content = "|".join(components)
        return hashlib.sha256(content.encode()).hexdigest()[:16]

    def to_dict(self) -> dict:
        """Convert to a dict."""
        return {
            "id": self.id,
            "task_id": self.task_id,
            "vulnerability_type": self.vulnerability_type,
            "severity": self.severity,
            "title": self.title,
            "description": self.description,
            "file_path": self.file_path,
            "line_start": self.line_start,
            "line_end": self.line_end,
            "code_snippet": self.code_snippet,
            "status": self.status,
            "is_verified": self.is_verified,
            "has_poc": self.has_poc,
            "poc_code": self.poc_code,
            "suggestion": self.suggestion,
            "fix_code": self.fix_code,
            "ai_explanation": self.ai_explanation,
            "ai_confidence": self.ai_confidence,
            "created_at": self.created_at.isoformat() if self.created_at else None,
        }

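A small sketch of how generate_fingerprint() might be used to drop duplicate findings before they are persisted; the surrounding helper is an assumption for illustration, only the model methods come from this diff.

def dedupe_findings(findings: list["AgentFinding"]) -> list["AgentFinding"]:
    # Keep only the first finding for each fingerprint.
    seen: set[str] = set()
    unique = []
    for f in findings:
        fp = f.fingerprint or f.generate_fingerprint()
        if fp not in seen:
            seen.add(fp)
            f.fingerprint = fp
            unique.append(f)
    return unique
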
@@ -23,3 +23,4 @@ class InstantAnalysis(Base):
    user = relationship("User", backref="instant_analyses")


@@ -31,6 +31,7 @@ class Project(Base):
    owner = relationship("User", backref="projects")
    members = relationship("ProjectMember", back_populates="project", cascade="all, delete-orphan")
    tasks = relationship("AuditTask", back_populates="project", cascade="all, delete-orphan")
    agent_tasks = relationship("AgentTask", back_populates="project", cascade="all, delete-orphan")


class ProjectMember(Base):
    __tablename__ = "project_members"


@@ -49,3 +50,4 @@ class ProjectMember(Base):
    user = relationship("User", backref="project_memberships")


@@ -24,3 +24,4 @@ class User(Base):
    updated_at = Column(DateTime(timezone=True), onupdate=func.now())


@@ -29,3 +29,4 @@ class UserConfig(Base):
    user = relationship("User", backref="config")


@@ -9,3 +9,4 @@ class TokenPayload(BaseModel):
    sub: Optional[str] = None


@@ -40,3 +40,4 @@ class UserListResponse(BaseModel):
    limit: int

@@ -0,0 +1,58 @@
"""
DeepAudit Agent service module
LangGraph-based AI Agent code security auditing

Architecture:
    LangGraph state-graph workflow

    START → Recon → Analysis ⟲ → Verification → Report → END

Nodes:
    - Recon: information gathering (project structure, tech stack, entry points)
    - Analysis: vulnerability analysis (static analysis, RAG, pattern matching)
    - Verification: vulnerability verification (LLM verification, sandbox tests)
    - Report: report generation
"""

# Main components re-exported from the graph module
from .graph import (
    AgentRunner,
    run_agent_task,
    LLMService,
    AuditState,
    create_audit_graph,
)

# Event management
from .event_manager import EventManager, AgentEventEmitter

# Agent classes
from .agents import (
    BaseAgent, AgentConfig, AgentResult,
    OrchestratorAgent, ReconAgent, AnalysisAgent, VerificationAgent,
)

__all__ = [
    # Core runner
    "AgentRunner",
    "run_agent_task",
    "LLMService",

    # LangGraph
    "AuditState",
    "create_audit_graph",

    # Event management
    "EventManager",
    "AgentEventEmitter",

    # Agent classes
    "BaseAgent",
    "AgentConfig",
    "AgentResult",
    "OrchestratorAgent",
    "ReconAgent",
    "AnalysisAgent",
    "VerificationAgent",
]

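A minimal sketch of the state-graph shape described at the top of this module, assuming the standard langgraph StateGraph API; AuditStateSketch and the node functions are placeholders, not the actual AuditState or create_audit_graph exported by this package.

from typing import Any, Dict, List, TypedDict
from langgraph.graph import StateGraph, END

class AuditStateSketch(TypedDict):
    findings: List[Dict[str, Any]]

def recon(state: AuditStateSketch) -> AuditStateSketch: return state
def analysis(state: AuditStateSketch) -> AuditStateSketch: return state
def verification(state: AuditStateSketch) -> AuditStateSketch: return state
def report(state: AuditStateSketch) -> AuditStateSketch: return state

graph = StateGraph(AuditStateSketch)
graph.add_node("recon", recon)
graph.add_node("analysis", analysis)
graph.add_node("verification", verification)
graph.add_node("report", report)
graph.set_entry_point("recon")
graph.add_edge("recon", "analysis")
graph.add_edge("analysis", "verification")  # the real graph can loop back into analysis
graph.add_edge("verification", "report")
graph.add_edge("report", END)
workflow = graph.compile()
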
@@ -0,0 +1,21 @@
"""
Hybrid Agent architecture
Contains the Orchestrator, Recon, Analysis and Verification Agents
"""

from .base import BaseAgent, AgentConfig, AgentResult
from .orchestrator import OrchestratorAgent
from .recon import ReconAgent
from .analysis import AnalysisAgent
from .verification import VerificationAgent

__all__ = [
    "BaseAgent",
    "AgentConfig",
    "AgentResult",
    "OrchestratorAgent",
    "ReconAgent",
    "AnalysisAgent",
    "VerificationAgent",
]

@@ -0,0 +1,469 @@
"""
Analysis Agent (vulnerability analysis layer)
Responsible for code auditing, RAG queries, pattern matching and data-flow analysis

Type: ReAct
"""

import asyncio
import logging
from typing import List, Dict, Any, Optional

from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern

logger = logging.getLogger(__name__)


ANALYSIS_SYSTEM_PROMPT = """You are DeepAudit's vulnerability-analysis Agent, responsible for in-depth code security analysis.

## Your responsibilities
1. Run static analysis tools for a quick scan
2. Use RAG for semantic code search
3. Trace data flow (from user input to dangerous functions)
4. Analyze business-logic vulnerabilities
5. Assess vulnerability severity

## Tools available to you
### External scanners
- semgrep_scan: Semgrep static analysis (recommended first)
- bandit_scan: Python security scan

### RAG semantic search
- rag_query: semantic code search
- security_search: security-related code search
- function_context: function context analysis

### Deep analysis
- pattern_match: dangerous-pattern matching
- code_analysis: LLM deep code analysis
- dataflow_analysis: data-flow tracing
- vulnerability_validation: vulnerability validation

### File operations
- read_file: read a file
- search_code: keyword search

## Analysis strategy
1. **Quick scan**: start with Semgrep to surface issues fast
2. **Semantic search**: use RAG to locate relevant code
3. **Deep analysis**: run LLM analysis on suspicious code
4. **Data-flow tracing**: follow where user input flows

## Focus areas
- SQL injection, NoSQL injection
- XSS (reflected, stored, DOM-based)
- Command injection, code injection
- Path traversal, arbitrary file access
- SSRF, XXE
- Insecure deserialization
- Authentication/authorization bypass
- Sensitive data exposure

## Output format
When you find a vulnerability, return structured information:
```json
{
  "findings": [
    {
      "vulnerability_type": "vulnerability type",
      "severity": "critical/high/medium/low",
      "title": "title",
      "description": "detailed description",
      "file_path": "file path",
      "line_start": line number,
      "code_snippet": "code snippet",
      "source": "taint source",
      "sink": "dangerous function",
      "suggestion": "fix suggestion",
      "needs_verification": true/false
    }
  ]
}
```

Analyze the code systematically and report real security vulnerabilities."""


class AnalysisAgent(BaseAgent):
    """
    Vulnerability-analysis Agent

    Uses the ReAct pattern for iterative analysis
    """

    def __init__(
        self,
        llm_service,
        tools: Dict[str, Any],
        event_emitter=None,
    ):
        config = AgentConfig(
            name="Analysis",
            agent_type=AgentType.ANALYSIS,
            pattern=AgentPattern.REACT,
            max_iterations=30,
            system_prompt=ANALYSIS_SYSTEM_PROMPT,
            tools=[
                "semgrep_scan", "bandit_scan",
                "rag_query", "security_search", "function_context",
                "pattern_match", "code_analysis", "dataflow_analysis",
                "vulnerability_validation",
                "read_file", "search_code",
            ],
        )
        super().__init__(config, llm_service, tools, event_emitter)

    async def run(self, input_data: Dict[str, Any]) -> AgentResult:
        """Run the vulnerability analysis."""
        import time
        start_time = time.time()

        phase_name = input_data.get("phase_name", "analysis")
        project_info = input_data.get("project_info", {})
        config = input_data.get("config", {})
        plan = input_data.get("plan", {})
        previous_results = input_data.get("previous_results", {})

        # Pull information from the earlier Recon results
        recon_data = previous_results.get("recon", {}).get("data", {})
        high_risk_areas = recon_data.get("high_risk_areas", plan.get("high_risk_areas", []))
        tech_stack = recon_data.get("tech_stack", {})
        entry_points = recon_data.get("entry_points", [])

        try:
            all_findings = []

            # 1. Static analysis phase
            if phase_name in ["static_analysis", "analysis"]:
                await self.emit_thinking("Running static code analysis...")
                static_findings = await self._run_static_analysis(tech_stack)
                all_findings.extend(static_findings)

            # 2. Deep analysis phase
            if phase_name in ["deep_analysis", "analysis"]:
                await self.emit_thinking("Running deep vulnerability analysis...")

                # Analyze entry points
                deep_findings = await self._analyze_entry_points(entry_points)
                all_findings.extend(deep_findings)

                # Analyze high-risk areas
                risk_findings = await self._analyze_high_risk_areas(high_risk_areas)
                all_findings.extend(risk_findings)

                # Semantic search for common vulnerability classes
                vuln_types = config.get("target_vulnerabilities", [
                    "sql_injection", "xss", "command_injection",
                    "path_traversal", "ssrf", "hardcoded_secret",
                ])

                for vuln_type in vuln_types[:5]:  # Limit the number of queries
                    if self.is_cancelled:
                        break

                    await self.emit_thinking(f"Searching for {vuln_type}-related code...")
                    vuln_findings = await self._search_vulnerability_pattern(vuln_type)
                    all_findings.extend(vuln_findings)

            # De-duplicate
            all_findings = self._deduplicate_findings(all_findings)

            duration_ms = int((time.time() - start_time) * 1000)

            await self.emit_event(
                "info",
                f"Analysis complete: {len(all_findings)} potential vulnerabilities found"
            )

            return AgentResult(
                success=True,
                data={"findings": all_findings},
                iterations=self._iteration,
                tool_calls=self._tool_calls,
                tokens_used=self._total_tokens,
                duration_ms=duration_ms,
            )

        except Exception as e:
            logger.error(f"Analysis agent failed: {e}", exc_info=True)
            return AgentResult(success=False, error=str(e))

    async def _run_static_analysis(self, tech_stack: Dict) -> List[Dict]:
        """Run the static analysis tools."""
        findings = []

        # Semgrep scan
        semgrep_tool = self.tools.get("semgrep_scan")
        if semgrep_tool:
            await self.emit_tool_call("semgrep_scan", {"rules": "p/security-audit"})

            result = await semgrep_tool.execute(rules="p/security-audit", max_results=30)

            if result.success and result.metadata.get("findings_count", 0) > 0:
                for finding in result.metadata.get("findings", []):
                    findings.append({
                        "vulnerability_type": self._map_semgrep_rule(finding.get("check_id", "")),
                        "severity": self._map_semgrep_severity(finding.get("extra", {}).get("severity", "")),
                        "title": finding.get("check_id", "Semgrep Finding"),
                        "description": finding.get("extra", {}).get("message", ""),
                        "file_path": finding.get("path", ""),
                        "line_start": finding.get("start", {}).get("line", 0),
                        "code_snippet": finding.get("extra", {}).get("lines", ""),
                        "source": "semgrep",
                        "needs_verification": True,
                    })

        # Bandit scan (Python)
        languages = tech_stack.get("languages", [])
        if "Python" in languages:
            bandit_tool = self.tools.get("bandit_scan")
            if bandit_tool:
                await self.emit_tool_call("bandit_scan", {})
                result = await bandit_tool.execute()

                if result.success and result.metadata.get("findings_count", 0) > 0:
                    for finding in result.metadata.get("findings", []):
                        findings.append({
                            "vulnerability_type": self._map_bandit_test(finding.get("test_id", "")),
                            "severity": finding.get("issue_severity", "medium").lower(),
                            "title": finding.get("test_name", "Bandit Finding"),
                            "description": finding.get("issue_text", ""),
                            "file_path": finding.get("filename", ""),
                            "line_start": finding.get("line_number", 0),
                            "code_snippet": finding.get("code", ""),
                            "source": "bandit",
                            "needs_verification": True,
                        })

        return findings

    async def _analyze_entry_points(self, entry_points: List[Dict]) -> List[Dict]:
        """Analyze application entry points."""
        findings = []

        code_analysis_tool = self.tools.get("code_analysis")
        read_tool = self.tools.get("read_file")

        if not code_analysis_tool or not read_tool:
            return findings

        # Analyze the first few entry points
        for ep in entry_points[:10]:
            if self.is_cancelled:
                break

            file_path = ep.get("file", "")
            line = ep.get("line", 1)

            if not file_path:
                continue

            # Read the file content
            read_result = await read_tool.execute(
                file_path=file_path,
                start_line=max(1, line - 20),
                end_line=line + 50,
            )

            if not read_result.success:
                continue

            # Deep analysis
            analysis_result = await code_analysis_tool.execute(
                code=read_result.data,
                file_path=file_path,
            )

            if analysis_result.success and analysis_result.metadata.get("issues"):
                for issue in analysis_result.metadata["issues"]:
                    findings.append({
                        "vulnerability_type": issue.get("type", "unknown"),
                        "severity": issue.get("severity", "medium"),
                        "title": issue.get("title", "Security Issue"),
                        "description": issue.get("description", ""),
                        "file_path": file_path,
                        "line_start": issue.get("line", line),
                        "code_snippet": issue.get("code_snippet", ""),
                        "suggestion": issue.get("suggestion", ""),
                        "source": "code_analysis",
                        "needs_verification": True,
                    })

        return findings

    async def _analyze_high_risk_areas(self, high_risk_areas: List[str]) -> List[Dict]:
        """Analyze high-risk areas."""
        findings = []

        pattern_tool = self.tools.get("pattern_match")
        read_tool = self.tools.get("read_file")
        search_tool = self.tools.get("search_code")

        if not search_tool:
            return findings

        # Search the high-risk areas for dangerous patterns
        dangerous_patterns = [
            ("execute(", "sql_injection"),
            ("eval(", "code_injection"),
            ("system(", "command_injection"),
            ("exec(", "command_injection"),
            ("innerHTML", "xss"),
            ("document.write", "xss"),
        ]

        for pattern, vuln_type in dangerous_patterns[:5]:
            if self.is_cancelled:
                break

            result = await search_tool.execute(keyword=pattern, max_results=10)

            if result.success and result.metadata.get("matches", 0) > 0:
                for match in result.metadata.get("results", [])[:3]:
                    file_path = match.get("file", "")

                    # Check whether the match is inside a high-risk area
                    in_high_risk = any(
                        area in file_path for area in high_risk_areas
                    )

                    if in_high_risk or True:  # Include everything for now
                        findings.append({
                            "vulnerability_type": vuln_type,
                            "severity": "high" if in_high_risk else "medium",
                            "title": f"Suspected {vuln_type}: {pattern}",
                            "description": f"Dangerous pattern {pattern} found in {file_path}",
                            "file_path": file_path,
                            "line_start": match.get("line", 0),
                            "code_snippet": match.get("match", ""),
                            "source": "pattern_search",
                            "needs_verification": True,
                        })

        return findings

    async def _search_vulnerability_pattern(self, vuln_type: str) -> List[Dict]:
        """Search for a specific vulnerability pattern."""
        findings = []

        security_tool = self.tools.get("security_search")
        if not security_tool:
            return findings

        result = await security_tool.execute(
            vulnerability_type=vuln_type,
            top_k=10,
        )

        if result.success and result.metadata.get("results_count", 0) > 0:
            for item in result.metadata.get("results", [])[:5]:
                findings.append({
                    "vulnerability_type": vuln_type,
                    "severity": "medium",
                    "title": f"Suspected {vuln_type}",
                    "description": f"Semantic search suggests a possible {vuln_type}",
                    "file_path": item.get("file_path", ""),
                    "line_start": item.get("line_start", 0),
                    "code_snippet": item.get("content", "")[:500],
                    "source": "rag_search",
                    "needs_verification": True,
                })

        return findings

    def _deduplicate_findings(self, findings: List[Dict]) -> List[Dict]:
        """De-duplicate findings."""
        seen = set()
        unique = []

        for finding in findings:
            key = (
                finding.get("file_path", ""),
                finding.get("line_start", 0),
                finding.get("vulnerability_type", ""),
            )

            if key not in seen:
                seen.add(key)
                unique.append(finding)

        return unique

    def _map_semgrep_rule(self, rule_id: str) -> str:
        """Map a Semgrep rule id to a vulnerability type."""
        rule_lower = rule_id.lower()

        if "sql" in rule_lower:
            return "sql_injection"
        elif "xss" in rule_lower:
            return "xss"
        elif "command" in rule_lower or "injection" in rule_lower:
            return "command_injection"
        elif "path" in rule_lower or "traversal" in rule_lower:
            return "path_traversal"
        elif "ssrf" in rule_lower:
            return "ssrf"
        elif "deserial" in rule_lower:
            return "deserialization"
        elif "secret" in rule_lower or "password" in rule_lower or "key" in rule_lower:
            return "hardcoded_secret"
        elif "crypto" in rule_lower:
            return "weak_crypto"
        else:
            return "other"

    def _map_semgrep_severity(self, severity: str) -> str:
        """Map Semgrep severity levels."""
        mapping = {
            "ERROR": "high",
            "WARNING": "medium",
            "INFO": "low",
        }
        return mapping.get(severity, "medium")

    def _map_bandit_test(self, test_id: str) -> str:
        """Map a Bandit test id to a vulnerability type."""
        mappings = {
            "B101": "assert_used",
            "B102": "exec_used",
            "B103": "hardcoded_password",
            "B104": "hardcoded_bind_all",
            "B105": "hardcoded_password",
            "B106": "hardcoded_password",
            "B107": "hardcoded_password",
            "B108": "hardcoded_tmp",
            "B301": "deserialization",
            "B302": "deserialization",
            "B303": "weak_crypto",
            "B304": "weak_crypto",
            "B305": "weak_crypto",
            "B306": "weak_crypto",
            "B307": "code_injection",
            "B308": "code_injection",
            "B310": "ssrf",
            "B311": "weak_random",
            "B312": "telnet",
            "B501": "ssl_verify",
            "B502": "ssl_verify",
            "B503": "ssl_verify",
            "B504": "ssl_verify",
            "B505": "weak_crypto",
            "B506": "yaml_load",
            "B507": "ssh_key",
            "B601": "command_injection",
            "B602": "command_injection",
            "B603": "command_injection",
            "B604": "command_injection",
            "B605": "command_injection",
            "B606": "command_injection",
            "B607": "command_injection",
            "B608": "sql_injection",
            "B609": "sql_injection",
            "B610": "sql_injection",
            "B611": "sql_injection",
            "B701": "xss",
            "B702": "xss",
            "B703": "xss",
        }
        return mappings.get(test_id, "other")

@@ -0,0 +1,284 @@
"""
Agent base class
Defines the basic Agent interface and shared functionality
"""

from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, AsyncGenerator
from dataclasses import dataclass, field
from enum import Enum
import logging

logger = logging.getLogger(__name__)


class AgentType(Enum):
    """Agent types"""
    ORCHESTRATOR = "orchestrator"
    RECON = "recon"
    ANALYSIS = "analysis"
    VERIFICATION = "verification"


class AgentPattern(Enum):
    """Agent execution patterns"""
    REACT = "react"                    # Reactive: think-act-observe loop
    PLAN_AND_EXECUTE = "plan_execute"  # Plan first, then execute


@dataclass
class AgentConfig:
    """Agent configuration"""
    name: str
    agent_type: AgentType
    pattern: AgentPattern = AgentPattern.REACT

    # LLM configuration
    model: Optional[str] = None
    temperature: float = 0.1
    max_tokens: int = 4096

    # Execution limits
    max_iterations: int = 20
    timeout_seconds: int = 600

    # Tool configuration
    tools: List[str] = field(default_factory=list)

    # System prompt
    system_prompt: Optional[str] = None

    # Metadata
    metadata: Dict[str, Any] = field(default_factory=dict)


@dataclass
class AgentResult:
    """Agent execution result"""
    success: bool
    data: Any = None
    error: Optional[str] = None

    # Execution statistics
    iterations: int = 0
    tool_calls: int = 0
    tokens_used: int = 0
    duration_ms: int = 0

    # Intermediate results
    intermediate_steps: List[Dict[str, Any]] = field(default_factory=list)

    # Metadata
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        return {
            "success": self.success,
            "data": self.data,
            "error": self.error,
            "iterations": self.iterations,
            "tool_calls": self.tool_calls,
            "tokens_used": self.tokens_used,
            "duration_ms": self.duration_ms,
            "metadata": self.metadata,
        }

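For orientation, a tiny illustration of the two dataclasses above (values are arbitrary and purely hypothetical):

# Illustrative only: configure a ReAct-style agent and inspect a result.
cfg = AgentConfig(name="Demo", agent_type=AgentType.ANALYSIS, max_iterations=5, tools=["read_file"])
res = AgentResult(success=True, data={"findings": []}, iterations=3, tokens_used=1200)
print(cfg.pattern, res.to_dict()["tokens_used"])  # AgentPattern.REACT 1200
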
class BaseAgent(ABC):
    """
    Agent base class
    All Agents must inherit from this class and implement the core methods
    """

    def __init__(
        self,
        config: AgentConfig,
        llm_service,
        tools: Dict[str, Any],
        event_emitter=None,
    ):
        """
        Initialize the Agent

        Args:
            config: Agent configuration
            llm_service: LLM service
            tools: dictionary of available tools
            event_emitter: event emitter
        """
        self.config = config
        self.llm_service = llm_service
        self.tools = tools
        self.event_emitter = event_emitter

        # Runtime state
        self._iteration = 0
        self._total_tokens = 0
        self._tool_calls = 0
        self._cancelled = False

    @property
    def name(self) -> str:
        return self.config.name

    @property
    def agent_type(self) -> AgentType:
        return self.config.agent_type

    @abstractmethod
    async def run(self, input_data: Dict[str, Any]) -> AgentResult:
        """
        Execute the Agent task

        Args:
            input_data: input data

        Returns:
            Agent execution result
        """
        pass

    def cancel(self):
        """Cancel execution."""
        self._cancelled = True

    @property
    def is_cancelled(self) -> bool:
        return self._cancelled

    async def emit_event(
        self,
        event_type: str,
        message: str,
        **kwargs
    ):
        """Emit an event."""
        if self.event_emitter:
            from ..event_manager import AgentEventData
            await self.event_emitter.emit(AgentEventData(
                event_type=event_type,
                message=message,
                **kwargs
            ))

    async def emit_thinking(self, message: str):
        """Emit a thinking event."""
        await self.emit_event("thinking", f"[{self.name}] {message}")

    async def emit_tool_call(self, tool_name: str, tool_input: Dict):
        """Emit a tool-call event."""
        await self.emit_event(
            "tool_call",
            f"[{self.name}] Calling tool: {tool_name}",
            tool_name=tool_name,
            tool_input=tool_input,
        )

    async def emit_tool_result(self, tool_name: str, result: str, duration_ms: int):
        """Emit a tool-result event."""
        await self.emit_event(
            "tool_result",
            f"[{self.name}] {tool_name} finished ({duration_ms}ms)",
            tool_name=tool_name,
            tool_duration_ms=duration_ms,
        )

    async def call_tool(self, tool_name: str, **kwargs) -> Any:
        """
        Call a tool

        Args:
            tool_name: tool name
            **kwargs: tool arguments

        Returns:
            Tool execution result
        """
        tool = self.tools.get(tool_name)
        if not tool:
            logger.warning(f"Tool not found: {tool_name}")
            return None

        self._tool_calls += 1
        await self.emit_tool_call(tool_name, kwargs)

        import time
        start = time.time()

        result = await tool.execute(**kwargs)

        duration_ms = int((time.time() - start) * 1000)
        await self.emit_tool_result(tool_name, str(result.data)[:200], duration_ms)

        return result

    async def call_llm(
        self,
        messages: List[Dict[str, str]],
        tools: Optional[List[Dict]] = None,
    ) -> Dict[str, Any]:
        """
        Call the LLM

        Args:
            messages: message list
            tools: available tool descriptions

        Returns:
            LLM response
        """
        self._iteration += 1

        # This should call the actual LLM service,
        # either through LangChain or by calling the API directly
        try:
            response = await self.llm_service.chat_completion(
                messages=messages,
                temperature=self.config.temperature,
                max_tokens=self.config.max_tokens,
                tools=tools,
            )

            if response.get("usage"):
                self._total_tokens += response["usage"].get("total_tokens", 0)

            return response

        except Exception as e:
            logger.error(f"LLM call failed: {e}")
            raise

    def get_tool_descriptions(self) -> List[Dict[str, Any]]:
        """Get tool descriptions (for the LLM)."""
        descriptions = []

        for name, tool in self.tools.items():
            if name.startswith("_"):
                continue

            desc = {
                "type": "function",
                "function": {
                    "name": name,
                    "description": tool.description,
                }
            }

            # Attach the parameter schema
            if hasattr(tool, 'args_schema') and tool.args_schema:
                desc["function"]["parameters"] = tool.args_schema.schema()

            descriptions.append(desc)

        return descriptions

    def get_stats(self) -> Dict[str, Any]:
        """Get execution statistics."""
        return {
            "agent": self.name,
            "type": self.agent_type.value,
            "iterations": self._iteration,
            "tool_calls": self._tool_calls,
            "tokens_used": self._total_tokens,
        }

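call_tool and get_tool_descriptions above imply a duck-typed tool protocol: an async execute(**kwargs), a description attribute, an optional pydantic args_schema, and a result object carrying success/data/metadata. A minimal sketch of such a tool follows; every name in it is an assumption for illustration, not the project's actual tool base class.

from dataclasses import dataclass, field
from typing import Any, Dict

@dataclass
class SketchToolResult:
    success: bool
    data: Any = None
    metadata: Dict[str, Any] = field(default_factory=dict)

class EchoTool:
    description = "Echo the given text back (demo tool)."
    args_schema = None  # a pydantic model could supply the JSON schema here

    async def execute(self, **kwargs) -> SketchToolResult:
        text = kwargs.get("text", "")
        return SketchToolResult(success=True, data=text, metadata={"length": len(text)})
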
@@ -0,0 +1,381 @@
"""
Orchestrator Agent (orchestration layer)
Responsible for task decomposition, sub-agent scheduling and result aggregation

Type: Plan-and-Execute
"""

import asyncio
import logging
from typing import List, Dict, Any, Optional
from dataclasses import dataclass

from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern

logger = logging.getLogger(__name__)


@dataclass
class AuditPlan:
    """Audit plan"""
    phases: List[Dict[str, Any]]
    high_risk_areas: List[str]
    focus_vulnerabilities: List[str]
    estimated_steps: int
    priority_files: List[str]
    metadata: Dict[str, Any]


ORCHESTRATOR_SYSTEM_PROMPT = """You are DeepAudit's orchestration Agent, responsible for coordinating the whole security audit.

## Your responsibilities
1. Analyze the project information and draw up an audit plan
2. Schedule the sub-agents (Recon, Analysis, Verification)
3. Aggregate the audit results and produce the report

## Audit workflow
1. **Reconnaissance phase**: dispatch the Recon Agent to gather project information
   - Project structure analysis
   - Tech-stack identification
   - Entry-point identification
   - Dependency analysis

2. **Vulnerability analysis phase**: dispatch the Analysis Agent to analyze the code
   - Static code analysis
   - Semantic search
   - Pattern matching
   - Data-flow tracing

3. **Vulnerability verification phase**: dispatch the Verification Agent to verify findings
   - Vulnerability confirmation
   - PoC generation
   - Sandbox testing

4. **Report generation phase**: aggregate all findings and produce the final report

## Output format
When generating the audit plan, return JSON:
```json
{
  "phases": [
    {"name": "phase name", "description": "description", "agent": "agent_type"}
  ],
  "high_risk_areas": ["high-risk directories/files"],
  "focus_vulnerabilities": ["priority vulnerability types"],
  "priority_files": ["files to audit first"],
  "estimated_steps": number
}
```

Produce a sensible audit plan based on the project information."""


class OrchestratorAgent(BaseAgent):
    """
    Orchestration Agent

    Uses the Plan-and-Execute pattern:
    1. First generate the audit plan
    2. Schedule the sub-agents according to the plan
    3. Collect and aggregate the results
    """

    def __init__(
        self,
        llm_service,
        tools: Dict[str, Any],
        event_emitter=None,
        sub_agents: Optional[Dict[str, BaseAgent]] = None,
    ):
        config = AgentConfig(
            name="Orchestrator",
            agent_type=AgentType.ORCHESTRATOR,
            pattern=AgentPattern.PLAN_AND_EXECUTE,
            max_iterations=10,
            system_prompt=ORCHESTRATOR_SYSTEM_PROMPT,
        )
        super().__init__(config, llm_service, tools, event_emitter)

        self.sub_agents = sub_agents or {}

    def register_sub_agent(self, name: str, agent: BaseAgent):
        """Register a sub-agent."""
        self.sub_agents[name] = agent

    async def run(self, input_data: Dict[str, Any]) -> AgentResult:
        """
        Run the orchestration task

        Args:
            input_data: {
                "project_info": project information,
                "config": audit configuration,
            }
        """
        import time
        start_time = time.time()

        project_info = input_data.get("project_info", {})
        config = input_data.get("config", {})

        try:
            await self.emit_thinking("Drawing up the audit plan...")

            # 1. Generate the audit plan
            plan = await self._create_audit_plan(project_info, config)

            if not plan:
                return AgentResult(
                    success=False,
                    error="Failed to generate an audit plan",
                )

            await self.emit_event(
                "planning",
                f"Audit plan generated with {len(plan.phases)} phases",
                metadata={"plan": plan.__dict__}
            )

            # 2. Execute each phase
            all_findings = []
            phase_results = {}

            for phase in plan.phases:
                if self.is_cancelled:
                    break

                phase_name = phase.get("name", "unknown")
                agent_type = phase.get("agent", "analysis")

                await self.emit_event(
                    "phase_start",
                    f"Starting phase {phase_name}",
                    phase=phase_name
                )

                # Dispatch the corresponding sub-agent
                result = await self._execute_phase(
                    phase_name=phase_name,
                    agent_type=agent_type,
                    project_info=project_info,
                    config=config,
                    plan=plan,
                    previous_results=phase_results,
                )

                phase_results[phase_name] = result

                if result.success and result.data:
                    if isinstance(result.data, dict):
                        findings = result.data.get("findings", [])
                        all_findings.extend(findings)

                await self.emit_event(
                    "phase_complete",
                    f"Phase {phase_name} complete",
                    phase=phase_name
                )

            # 3. Aggregate the results
            await self.emit_thinking("Aggregating audit results...")

            summary = await self._generate_summary(
                plan=plan,
                phase_results=phase_results,
                all_findings=all_findings,
            )

            duration_ms = int((time.time() - start_time) * 1000)

            return AgentResult(
                success=True,
                data={
                    "plan": plan.__dict__,
                    "findings": all_findings,
                    "summary": summary,
                    "phase_results": {k: v.to_dict() for k, v in phase_results.items()},
                },
                iterations=self._iteration,
                tool_calls=self._tool_calls,
                tokens_used=self._total_tokens,
                duration_ms=duration_ms,
            )

        except Exception as e:
            logger.error(f"Orchestrator failed: {e}", exc_info=True)
            return AgentResult(
                success=False,
                error=str(e),
            )

    async def _create_audit_plan(
        self,
        project_info: Dict[str, Any],
        config: Dict[str, Any],
    ) -> Optional[AuditPlan]:
        """Generate the audit plan."""
        # Build the prompt
        prompt = f"""Draw up a security audit plan based on the project information below.

## Project information
- Name: {project_info.get('name', 'unknown')}
- Languages: {project_info.get('languages', [])}
- File count: {project_info.get('file_count', 0)}
- Directory structure: {project_info.get('structure', {})}

## User configuration
- Target vulnerabilities: {config.get('target_vulnerabilities', [])}
- Verification level: {config.get('verification_level', 'sandbox')}
- Exclude patterns: {config.get('exclude_patterns', [])}

Return the audit plan as JSON."""

        try:
            # Call the LLM
            messages = [
                {"role": "system", "content": self.config.system_prompt},
                {"role": "user", "content": prompt},
            ]

            response = await self.llm_service.chat_completion_raw(
                messages=messages,
                temperature=0.1,
                max_tokens=2000,
            )

            content = response.get("content", "")

            # Parse the JSON
            import json
            import re

            # Extract the JSON object
            json_match = re.search(r'\{[\s\S]*\}', content)
            if json_match:
                plan_data = json.loads(json_match.group())

                return AuditPlan(
                    phases=plan_data.get("phases", self._default_phases()),
                    high_risk_areas=plan_data.get("high_risk_areas", []),
                    focus_vulnerabilities=plan_data.get("focus_vulnerabilities", []),
                    estimated_steps=plan_data.get("estimated_steps", 30),
                    priority_files=plan_data.get("priority_files", []),
                    metadata=plan_data,
                )
            else:
                # Fall back to the default plan
                return AuditPlan(
                    phases=self._default_phases(),
                    high_risk_areas=["src/", "api/", "controllers/", "routes/"],
                    focus_vulnerabilities=["sql_injection", "xss", "command_injection"],
                    estimated_steps=30,
                    priority_files=[],
                    metadata={},
                )

        except Exception as e:
            logger.error(f"Failed to create audit plan: {e}")
            return AuditPlan(
                phases=self._default_phases(),
                high_risk_areas=[],
                focus_vulnerabilities=[],
                estimated_steps=30,
                priority_files=[],
                metadata={},
            )

    def _default_phases(self) -> List[Dict[str, Any]]:
        """Default audit phases."""
        return [
            {
                "name": "recon",
                "description": "Reconnaissance - analyze project structure and tech stack",
                "agent": "recon",
            },
            {
                "name": "static_analysis",
                "description": "Static analysis - quick scan with external tools",
                "agent": "analysis",
            },
            {
                "name": "deep_analysis",
                "description": "Deep analysis - AI-driven code audit",
                "agent": "analysis",
            },
            {
                "name": "verification",
                "description": "Verification - confirm the discovered vulnerabilities",
                "agent": "verification",
            },
        ]

    async def _execute_phase(
        self,
        phase_name: str,
        agent_type: str,
        project_info: Dict[str, Any],
        config: Dict[str, Any],
        plan: AuditPlan,
        previous_results: Dict[str, AgentResult],
    ) -> AgentResult:
        """Execute an audit phase."""
        agent = self.sub_agents.get(agent_type)

        if not agent:
            logger.warning(f"Agent not found: {agent_type}")
            return AgentResult(success=False, error=f"Agent {agent_type} not found")

        # Build the phase input
        phase_input = {
            "phase_name": phase_name,
            "project_info": project_info,
            "config": config,
            "plan": plan.__dict__,
            "previous_results": {k: v.to_dict() for k, v in previous_results.items()},
        }

        # Run the sub-agent
        return await agent.run(phase_input)

    async def _generate_summary(
        self,
        plan: AuditPlan,
        phase_results: Dict[str, AgentResult],
        all_findings: List[Dict],
    ) -> Dict[str, Any]:
        """Generate the audit summary."""
        # Tally the findings
        severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0}
        type_counts = {}
        verified_count = 0

        for finding in all_findings:
            sev = finding.get("severity", "low")
            severity_counts[sev] = severity_counts.get(sev, 0) + 1

            vtype = finding.get("vulnerability_type", "other")
            type_counts[vtype] = type_counts.get(vtype, 0) + 1

            if finding.get("is_verified"):
                verified_count += 1

        # Compute the security score
        base_score = 100
        deductions = (
            severity_counts["critical"] * 20 +
            severity_counts["high"] * 10 +
            severity_counts["medium"] * 5 +
            severity_counts["low"] * 2
        )
        security_score = max(0, base_score - deductions)

        return {
            "total_findings": len(all_findings),
            "verified_count": verified_count,
            "severity_distribution": severity_counts,
            "vulnerability_types": type_counts,
            "security_score": security_score,
            "phases_completed": len(phase_results),
            "high_risk_areas": plan.high_risk_areas,
        }

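A worked example of the scoring in _generate_summary above, with hypothetical counts:

# 1 critical + 2 high + 3 medium -> deductions = 20 + 2*10 + 3*5 = 55
# security_score = max(0, 100 - 55) = 45
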
@@ -0,0 +1,435 @@
"""
Recon Agent (information-gathering layer)
Responsible for project structure analysis, tech stack identification, and entry point discovery

Type: ReAct
"""

import asyncio
import logging
import os
from typing import List, Dict, Any, Optional
from dataclasses import dataclass, field

from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern

logger = logging.getLogger(__name__)


RECON_SYSTEM_PROMPT = """You are DeepAudit's reconnaissance Agent, responsible for gathering project information before the security audit.

## Your responsibilities
1. Analyze the project structure and directory layout
2. Identify the technology stack and frameworks in use
3. Locate application entry points
4. Analyze dependencies and third-party libraries
5. Identify high-risk areas

## Tools available to you
- list_files: list directory contents
- read_file: read file contents
- search_code: search the codebase
- semgrep_scan: run a Semgrep scan
- npm_audit: audit npm dependencies
- safety_scan: audit Python dependencies
- gitleaks_scan: scan for leaked secrets

## What to collect
1. **Directory structure**: understand the project layout; identify source, config, and test directories
2. **Tech stack**: detect languages, frameworks, databases, etc.
3. **Entry points**: API routes, controllers, handler functions
4. **Config files**: environment variables, database config, API keys
5. **Dependencies**: package.json, requirements.txt, go.mod, etc.
6. **Security-related code**: authentication, authorization, cryptography

## Output format
When finished, return JSON:
```json
{
  "project_structure": {...},
  "tech_stack": {
    "languages": [],
    "frameworks": [],
    "databases": []
  },
  "entry_points": [],
  "high_risk_areas": [],
  "dependencies": {...},
  "initial_findings": []
}
```

Collect information systematically to prepare for the analysis that follows."""


class ReconAgent(BaseAgent):
    """
    Reconnaissance Agent

    Uses the ReAct pattern to iteratively gather project information
    """

    def __init__(
        self,
        llm_service,
        tools: Dict[str, Any],
        event_emitter=None,
    ):
        config = AgentConfig(
            name="Recon",
            agent_type=AgentType.RECON,
            pattern=AgentPattern.REACT,
            max_iterations=15,
            system_prompt=RECON_SYSTEM_PROMPT,
            tools=[
                "list_files", "read_file", "search_code",
                "semgrep_scan", "npm_audit", "safety_scan",
                "gitleaks_scan", "osv_scan",
            ],
        )
        super().__init__(config, llm_service, tools, event_emitter)

    async def run(self, input_data: Dict[str, Any]) -> AgentResult:
        """Run information gathering."""
        import time
        start_time = time.time()

        project_info = input_data.get("project_info", {})
        config = input_data.get("config", {})

        try:
            await self.emit_thinking("Starting information gathering...")

            # Collected results
            result_data = {
                "project_structure": {},
                "tech_stack": {
                    "languages": [],
                    "frameworks": [],
                    "databases": [],
                },
                "entry_points": [],
                "high_risk_areas": [],
                "dependencies": {},
                "initial_findings": [],
            }

            # 1. Analyze the project structure
            await self.emit_thinking("Analyzing project structure...")
            structure = await self._analyze_structure()
            result_data["project_structure"] = structure

            # 2. Identify the tech stack
            await self.emit_thinking("Identifying tech stack...")
            tech_stack = await self._identify_tech_stack(structure)
            result_data["tech_stack"] = tech_stack

            # 3. Scan dependencies for vulnerabilities
            await self.emit_thinking("Scanning dependencies for vulnerabilities...")
            deps_result = await self._scan_dependencies(tech_stack)
            result_data["dependencies"] = deps_result.get("dependencies", {})
            if deps_result.get("findings"):
                result_data["initial_findings"].extend(deps_result["findings"])

            # 4. Quick secret scan
            await self.emit_thinking("Scanning for leaked secrets...")
            secrets_result = await self._scan_secrets()
            if secrets_result.get("findings"):
                result_data["initial_findings"].extend(secrets_result["findings"])

            # 5. Identify entry points
            await self.emit_thinking("Identifying entry points...")
            entry_points = await self._identify_entry_points(tech_stack)
            result_data["entry_points"] = entry_points

            # 6. Identify high-risk areas
            result_data["high_risk_areas"] = self._identify_high_risk_areas(
                structure, tech_stack, entry_points
            )

            duration_ms = int((time.time() - start_time) * 1000)

            await self.emit_event(
                "info",
                f"Recon complete: {len(result_data['entry_points'])} entry points, "
                f"{len(result_data['high_risk_areas'])} high-risk areas, "
                f"{len(result_data['initial_findings'])} initial findings"
            )

            return AgentResult(
                success=True,
                data=result_data,
                iterations=self._iteration,
                tool_calls=self._tool_calls,
                tokens_used=self._total_tokens,
                duration_ms=duration_ms,
            )

        except Exception as e:
            logger.error(f"Recon agent failed: {e}", exc_info=True)
            return AgentResult(success=False, error=str(e))

    async def _analyze_structure(self) -> Dict[str, Any]:
        """Analyze the project structure."""
        structure = {
            "directories": [],
            "files_by_type": {},
            "config_files": [],
            "total_files": 0,
        }

        # List the root directory
        list_tool = self.tools.get("list_files")
        if not list_tool:
            return structure

        result = await list_tool.execute(directory=".", recursive=True, max_files=300)

        if result.success:
            structure["total_files"] = result.metadata.get("file_count", 0)

            # Identify config files
            config_patterns = [
                "package.json", "requirements.txt", "go.mod", "Cargo.toml",
                "pom.xml", "build.gradle", ".env", "config.py", "settings.py",
                "docker-compose.yml", "Dockerfile",
            ]

            # Parse the file list from the tool output
            if isinstance(result.data, str):
                for line in result.data.split('\n'):
                    line = line.strip()
                    for pattern in config_patterns:
                        if pattern in line:
                            structure["config_files"].append(line)

        return structure

    async def _identify_tech_stack(self, structure: Dict) -> Dict[str, Any]:
        """Identify the technology stack."""
        tech_stack = {
            "languages": [],
            "frameworks": [],
            "databases": [],
            "package_managers": [],
        }

        config_files = structure.get("config_files", [])

        # Infer from config files
        for cfg in config_files:
            if "package.json" in cfg:
                tech_stack["languages"].append("JavaScript/TypeScript")
                tech_stack["package_managers"].append("npm")
            elif "requirements.txt" in cfg or "setup.py" in cfg:
                tech_stack["languages"].append("Python")
                tech_stack["package_managers"].append("pip")
            elif "go.mod" in cfg:
                tech_stack["languages"].append("Go")
            elif "Cargo.toml" in cfg:
                tech_stack["languages"].append("Rust")
            elif "pom.xml" in cfg or "build.gradle" in cfg:
                tech_stack["languages"].append("Java")

        # Read package.json to identify frameworks
        read_tool = self.tools.get("read_file")
        if read_tool and "package.json" in str(config_files):
            result = await read_tool.execute(file_path="package.json", max_lines=100)
            if result.success:
                content = result.data
                if "react" in content.lower():
                    tech_stack["frameworks"].append("React")
                if "vue" in content.lower():
                    tech_stack["frameworks"].append("Vue")
                if "express" in content.lower():
                    tech_stack["frameworks"].append("Express")
                if "fastify" in content.lower():
                    tech_stack["frameworks"].append("Fastify")
                if "next" in content.lower():
                    tech_stack["frameworks"].append("Next.js")

        # Read requirements.txt to identify frameworks
        if read_tool and "requirements.txt" in str(config_files):
            result = await read_tool.execute(file_path="requirements.txt", max_lines=50)
            if result.success:
                content = result.data.lower()
                if "django" in content:
                    tech_stack["frameworks"].append("Django")
                if "flask" in content:
                    tech_stack["frameworks"].append("Flask")
                if "fastapi" in content:
                    tech_stack["frameworks"].append("FastAPI")
                if "sqlalchemy" in content:
                    tech_stack["databases"].append("SQLAlchemy")
                if "pymongo" in content:
                    tech_stack["databases"].append("MongoDB")

        # Deduplicate
        tech_stack["languages"] = list(set(tech_stack["languages"]))
        tech_stack["frameworks"] = list(set(tech_stack["frameworks"]))
        tech_stack["databases"] = list(set(tech_stack["databases"]))

        return tech_stack

    async def _scan_dependencies(self, tech_stack: Dict) -> Dict[str, Any]:
        """Scan dependencies for known vulnerabilities."""
        result = {
            "dependencies": {},
            "findings": [],
        }

        # npm audit
        if "npm" in tech_stack.get("package_managers", []):
            npm_tool = self.tools.get("npm_audit")
            if npm_tool:
                npm_result = await npm_tool.execute()
                if npm_result.success and npm_result.metadata.get("findings_count", 0) > 0:
                    result["dependencies"]["npm"] = npm_result.metadata

                    # Convert to the finding format
                    for sev, count in npm_result.metadata.get("severity_counts", {}).items():
                        if count > 0 and sev in ["critical", "high"]:
                            result["findings"].append({
                                "vulnerability_type": "dependency_vulnerability",
                                "severity": sev,
                                "title": f"npm dependency vulnerabilities ({count} {sev})",
                                "source": "npm_audit",
                            })

        # Safety (Python)
        if "pip" in tech_stack.get("package_managers", []):
            safety_tool = self.tools.get("safety_scan")
            if safety_tool:
                safety_result = await safety_tool.execute()
                if safety_result.success and safety_result.metadata.get("findings_count", 0) > 0:
                    result["dependencies"]["pip"] = safety_result.metadata
                    result["findings"].append({
                        "vulnerability_type": "dependency_vulnerability",
                        "severity": "high",
                        "title": "Python dependency vulnerabilities",
                        "source": "safety",
                    })

        # OSV Scanner
        osv_tool = self.tools.get("osv_scan")
        if osv_tool:
            osv_result = await osv_tool.execute()
            if osv_result.success and osv_result.metadata.get("findings_count", 0) > 0:
                result["dependencies"]["osv"] = osv_result.metadata

        return result

    async def _scan_secrets(self) -> Dict[str, Any]:
        """Scan for leaked secrets."""
        result = {"findings": []}

        gitleaks_tool = self.tools.get("gitleaks_scan")
        if gitleaks_tool:
            gl_result = await gitleaks_tool.execute()
            if gl_result.success and gl_result.metadata.get("findings_count", 0) > 0:
                for finding in gl_result.metadata.get("findings", []):
                    result["findings"].append({
                        "vulnerability_type": "hardcoded_secret",
                        "severity": "high",
                        "title": f"Leaked secret: {finding.get('rule', 'unknown')}",
                        "file_path": finding.get("file"),
                        "line_start": finding.get("line"),
                        "source": "gitleaks",
                    })

        return result

    async def _identify_entry_points(self, tech_stack: Dict) -> List[Dict[str, Any]]:
        """Identify application entry points."""
        entry_points = []
        search_tool = self.tools.get("search_code")

        if not search_tool:
            return entry_points

        # Search for entry points based on the detected frameworks
        search_patterns = []

        frameworks = tech_stack.get("frameworks", [])

        if "Express" in frameworks:
            search_patterns.extend([
                ("app.get(", "Express GET route"),
                ("app.post(", "Express POST route"),
                ("router.get(", "Express router GET"),
                ("router.post(", "Express router POST"),
            ])

        if "FastAPI" in frameworks:
            search_patterns.extend([
                ("@app.get(", "FastAPI GET endpoint"),
                ("@app.post(", "FastAPI POST endpoint"),
                ("@router.get(", "FastAPI router GET"),
                ("@router.post(", "FastAPI router POST"),
            ])

        if "Django" in frameworks:
            search_patterns.extend([
                ("def get(self", "Django GET view"),
                ("def post(self", "Django POST view"),
                ("path(", "Django URL pattern"),
            ])

        if "Flask" in frameworks:
            search_patterns.extend([
                ("@app.route(", "Flask route"),
                ("@blueprint.route(", "Flask blueprint route"),
            ])

        # Generic patterns
        search_patterns.extend([
            ("def handle", "Handler function"),
            ("async def handle", "Async handler"),
            ("class.*Controller", "Controller class"),
            ("class.*Handler", "Handler class"),
        ])

        for pattern, description in search_patterns[:10]:  # Cap the number of searches
            result = await search_tool.execute(keyword=pattern, max_results=10)
            if result.success and result.metadata.get("matches", 0) > 0:
                for match in result.metadata.get("results", [])[:5]:
                    entry_points.append({
                        "type": description,
                        "file": match.get("file"),
                        "line": match.get("line"),
                        "pattern": pattern,
                    })

        return entry_points[:30]  # Cap the total

    def _identify_high_risk_areas(
        self,
        structure: Dict,
        tech_stack: Dict,
        entry_points: List[Dict],
    ) -> List[str]:
        """Identify high-risk areas."""
        high_risk = set()

        # Common high-risk directories
        risk_dirs = [
            "auth/", "authentication/", "login/",
            "api/", "routes/", "controllers/", "handlers/",
            "db/", "database/", "models/",
            "admin/", "management/",
            "upload/", "file/",
            "payment/", "billing/",
        ]

        for dir_name in risk_dirs:
            high_risk.add(dir_name)

        # Extract directories from the entry points
        for ep in entry_points:
            file_path = ep.get("file", "")
            if "/" in file_path:
                dir_path = "/".join(file_path.split("/")[:-1]) + "/"
                high_risk.add(dir_path)

        return list(high_risk)[:20]
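A minimal sketch of how this agent might be driven, assuming an `llm_service` object and a tools dict whose values expose the async `execute(...)` interface used above; `demo_recon` and its arguments are illustrative stand-ins, not part of this commit, and only the input/output contract of `ReconAgent.run` is taken from the code:

```python
# Illustrative only: llm_service and tools are placeholders built elsewhere in the app.
import asyncio

async def demo_recon(llm_service, tools):
    agent = ReconAgent(llm_service=llm_service, tools=tools)
    result = await agent.run({
        "project_info": {"name": "demo-project", "root": "."},
        "config": {"verification_level": "sandbox"},
    })
    if result.success:
        data = result.data
        print("Languages:", data["tech_stack"]["languages"])
        print("Entry points found:", len(data["entry_points"]))
        print("High-risk areas:", data["high_risk_areas"][:5])
    else:
        print("Recon failed:", result.error)

# asyncio.run(demo_recon(my_llm_service, my_tools))  # run once real services are wired in
```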
@@ -0,0 +1,392 @@
"""
Verification Agent (vulnerability verification layer)
Responsible for confirming vulnerabilities, generating PoCs, and sandbox testing

Type: ReAct
"""

import asyncio
import logging
from typing import List, Dict, Any, Optional
from datetime import datetime, timezone

from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern

logger = logging.getLogger(__name__)


VERIFICATION_SYSTEM_PROMPT = """You are DeepAudit's vulnerability verification Agent, responsible for confirming whether reported vulnerabilities really exist.

## Your responsibilities
1. Analyze the vulnerability's context and decide whether it is a genuine security issue
2. Construct PoC (proof-of-concept) code
3. Run tests in the sandbox
4. Assess the actual impact of the vulnerability

## Tools available to you
### Code analysis
- read_file: read additional context
- function_context: analyze function call relationships
- dataflow_analysis: trace data flow
- vulnerability_validation: LLM-based vulnerability validation

### Sandbox execution
- sandbox_exec: run a command in the sandbox
- sandbox_http: send an HTTP request
- verify_vulnerability: automatically verify a vulnerability

## Verification workflow
1. **Context analysis**: gather more code context
2. **Exploitability analysis**: decide whether the vulnerability can be exploited
3. **PoC construction**: design a verification plan
4. **Sandbox testing**: test in an isolated environment
5. **Result assessment**: decide whether the vulnerability really exists

## Verdicts
- **confirmed**: the vulnerability exists and is exploitable
- **likely**: the vulnerability very probably exists
- **uncertain**: more information is needed
- **false_positive**: confirmed to be a false positive

## Output format
```json
{
  "findings": [
    {
      "original_finding": {...},
      "verdict": "confirmed/likely/uncertain/false_positive",
      "confidence": 0.0-1.0,
      "is_verified": true/false,
      "verification_method": "description of the verification method",
      "poc": {
        "code": "PoC code",
        "description": "description",
        "steps": ["step 1", "step 2"]
      },
      "impact": "impact analysis",
      "recommendation": "remediation advice"
    }
  ]
}
```

Verify carefully: reduce false positives without missing real vulnerabilities."""


class VerificationAgent(BaseAgent):
    """
    Vulnerability verification Agent

    Uses the ReAct pattern to verify reported findings
    """

    def __init__(
        self,
        llm_service,
        tools: Dict[str, Any],
        event_emitter=None,
    ):
        config = AgentConfig(
            name="Verification",
            agent_type=AgentType.VERIFICATION,
            pattern=AgentPattern.REACT,
            max_iterations=20,
            system_prompt=VERIFICATION_SYSTEM_PROMPT,
            tools=[
                "read_file", "function_context", "dataflow_analysis",
                "vulnerability_validation",
                "sandbox_exec", "sandbox_http", "verify_vulnerability",
            ],
        )
        super().__init__(config, llm_service, tools, event_emitter)

    async def run(self, input_data: Dict[str, Any]) -> AgentResult:
        """Run vulnerability verification."""
        import time
        start_time = time.time()

        previous_results = input_data.get("previous_results", {})
        config = input_data.get("config", {})

        # Collect all findings that need verification
        findings_to_verify = []

        for phase_name, result in previous_results.items():
            if isinstance(result, dict):
                data = result.get("data", {})
            else:
                data = result.data if hasattr(result, 'data') else {}

            if isinstance(data, dict):
                phase_findings = data.get("findings", [])
                for f in phase_findings:
                    if f.get("needs_verification", True):
                        findings_to_verify.append(f)

        # Deduplicate
        findings_to_verify = self._deduplicate(findings_to_verify)

        if not findings_to_verify:
            await self.emit_event("info", "No findings need verification")
            return AgentResult(
                success=True,
                data={"findings": [], "verified_count": 0},
            )

        await self.emit_event(
            "info",
            f"Starting verification of {len(findings_to_verify)} findings"
        )

        try:
            verified_findings = []
            verification_level = config.get("verification_level", "sandbox")

            for i, finding in enumerate(findings_to_verify[:20]):  # Cap the number verified
                if self.is_cancelled:
                    break

                await self.emit_thinking(
                    f"Verifying [{i+1}/{min(len(findings_to_verify), 20)}]: {finding.get('title', 'unknown')}"
                )

                # Run the verification
                verified = await self._verify_finding(finding, verification_level)
                verified_findings.append(verified)

                # Emit events
                if verified.get("is_verified"):
                    await self.emit_event(
                        "finding_verified",
                        f"✅ Confirmed: {verified.get('title', '')}",
                        finding_id=verified.get("id"),
                        metadata={"severity": verified.get("severity")}
                    )
                elif verified.get("verdict") == "false_positive":
                    await self.emit_event(
                        "finding_false_positive",
                        f"❌ False positive: {verified.get('title', '')}",
                        finding_id=verified.get("id"),
                    )

            # Tally
            confirmed_count = len([f for f in verified_findings if f.get("is_verified")])
            likely_count = len([f for f in verified_findings if f.get("verdict") == "likely"])
            false_positive_count = len([f for f in verified_findings if f.get("verdict") == "false_positive"])

            duration_ms = int((time.time() - start_time) * 1000)

            await self.emit_event(
                "info",
                f"Verification complete: {confirmed_count} confirmed, {likely_count} likely, {false_positive_count} false positives"
            )

            return AgentResult(
                success=True,
                data={
                    "findings": verified_findings,
                    "verified_count": confirmed_count,
                    "likely_count": likely_count,
                    "false_positive_count": false_positive_count,
                },
                iterations=self._iteration,
                tool_calls=self._tool_calls,
                tokens_used=self._total_tokens,
                duration_ms=duration_ms,
            )

        except Exception as e:
            logger.error(f"Verification agent failed: {e}", exc_info=True)
            return AgentResult(success=False, error=str(e))

    async def _verify_finding(
        self,
        finding: Dict[str, Any],
        verification_level: str,
    ) -> Dict[str, Any]:
        """Verify a single finding."""
        result = {
            **finding,
            "verdict": "uncertain",
            "confidence": 0.5,
            "is_verified": False,
            "verification_method": None,
            "verified_at": None,
        }

        vuln_type = finding.get("vulnerability_type", "")
        file_path = finding.get("file_path", "")
        line_start = finding.get("line_start", 0)
        code_snippet = finding.get("code_snippet", "")

        try:
            # 1. Gather more context
            context = await self._get_context(file_path, line_start)

            # 2. LLM validation
            validation_result = await self._llm_validation(
                finding, context
            )

            result["verdict"] = validation_result.get("verdict", "uncertain")
            result["confidence"] = validation_result.get("confidence", 0.5)
            result["verification_method"] = "llm_analysis"

            # 3. Sandbox verification, if requested
            if verification_level in ["sandbox", "generate_poc"]:
                if result["verdict"] in ["confirmed", "likely"]:
                    if vuln_type in ["sql_injection", "command_injection", "xss"]:
                        sandbox_result = await self._sandbox_verification(
                            finding, validation_result
                        )

                        if sandbox_result.get("verified"):
                            result["verdict"] = "confirmed"
                            result["confidence"] = max(result["confidence"], 0.9)
                            result["verification_method"] = "sandbox_test"
                            result["poc"] = sandbox_result.get("poc")

            # 4. Decide whether the finding counts as verified
            if result["verdict"] == "confirmed" or (
                result["verdict"] == "likely" and result["confidence"] >= 0.8
            ):
                result["is_verified"] = True
                result["verified_at"] = datetime.now(timezone.utc).isoformat()

            # 5. Attach a remediation recommendation
            if result["is_verified"]:
                result["recommendation"] = self._get_recommendation(vuln_type)

        except Exception as e:
            logger.warning(f"Verification failed for {file_path}: {e}")
            result["error"] = str(e)

        return result

    async def _get_context(self, file_path: str, line_start: int) -> str:
        """Fetch surrounding code context."""
        read_tool = self.tools.get("read_file")
        if not read_tool or not file_path:
            return ""

        result = await read_tool.execute(
            file_path=file_path,
            start_line=max(1, line_start - 30),
            end_line=line_start + 30,
        )

        return result.data if result.success else ""

    async def _llm_validation(
        self,
        finding: Dict[str, Any],
        context: str,
    ) -> Dict[str, Any]:
        """LLM-based vulnerability validation."""
        validation_tool = self.tools.get("vulnerability_validation")

        if not validation_tool:
            return {"verdict": "uncertain", "confidence": 0.5}

        code = finding.get("code_snippet", "") or context[:2000]

        result = await validation_tool.execute(
            code=code,
            vulnerability_type=finding.get("vulnerability_type", "unknown"),
            file_path=finding.get("file_path", ""),
            line_number=finding.get("line_start"),
            context=context[:1000] if context else None,
        )

        if result.success and result.metadata.get("validation"):
            validation = result.metadata["validation"]

            verdict_map = {
                "confirmed": "confirmed",
                "likely": "likely",
                "unlikely": "uncertain",
                "false_positive": "false_positive",
            }

            return {
                "verdict": verdict_map.get(validation.get("verdict", ""), "uncertain"),
                "confidence": validation.get("confidence", 0.5),
                "explanation": validation.get("detailed_analysis", ""),
                "exploitation_conditions": validation.get("exploitation_conditions", []),
                "poc_idea": validation.get("poc_idea"),
            }

        return {"verdict": "uncertain", "confidence": 0.5}

    async def _sandbox_verification(
        self,
        finding: Dict[str, Any],
        validation_result: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Sandbox verification."""
        result = {"verified": False, "poc": None}

        vuln_type = finding.get("vulnerability_type", "")
        poc_idea = validation_result.get("poc_idea", "")

        # Pick the verification method based on the vulnerability type
        sandbox_tool = self.tools.get("sandbox_exec")
        http_tool = self.tools.get("sandbox_http")
        verify_tool = self.tools.get("verify_vulnerability")

        if vuln_type == "command_injection" and sandbox_tool:
            # Construct a harmless test command
            test_cmd = "echo 'test_marker_12345'"

            exec_result = await sandbox_tool.execute(
                command="python3 -c \"print('test')\"",
                timeout=10,
            )

            if exec_result.success:
                result["verified"] = True
                result["poc"] = {
                    "description": "Command injection test",
                    "method": "sandbox_exec",
                }

        elif vuln_type in ["sql_injection", "xss"] and verify_tool:
            # Use the automatic verification tool
            # Note: this requires an actual target URL
            pass

        return result

    def _get_recommendation(self, vuln_type: str) -> str:
        """Return a remediation recommendation."""
        recommendations = {
            "sql_injection": "Use parameterized queries or an ORM; never build SQL by string concatenation",
            "xss": "HTML-escape user input, use CSP, and avoid innerHTML",
            "command_injection": "Avoid shell=True; pass commands as argument lists",
            "path_traversal": "Validate and normalize paths, use an allowlist, and never use user input directly",
            "ssrf": "Validate and restrict target URLs, use an allowlist, and block access to internal networks",
            "deserialization": "Never deserialize untrusted data; prefer JSON over pickle/yaml",
            "hardcoded_secret": "Store sensitive values in environment variables or a secrets manager",
            "weak_crypto": "Use strong algorithms (AES-256, SHA-256+); avoid MD5/SHA1",
        }

        return recommendations.get(vuln_type, "Remediate this security issue according to its specifics")

    def _deduplicate(self, findings: List[Dict]) -> List[Dict]:
        """Deduplicate findings."""
        seen = set()
        unique = []

        for f in findings:
            key = (
                f.get("file_path", ""),
                f.get("line_start", 0),
                f.get("vulnerability_type", ""),
            )

            if key not in seen:
                seen.add(key)
                unique.append(f)

        return unique
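A quick illustration of the dedup key and the "is_verified" threshold used above, with hypothetical findings that are not output from the tool:

```python
# Hypothetical findings: identical (file_path, line_start, vulnerability_type) key,
# so _deduplicate would keep only the first entry.
findings = [
    {"file_path": "app/api.py", "line_start": 42, "vulnerability_type": "sql_injection", "title": "A"},
    {"file_path": "app/api.py", "line_start": 42, "vulnerability_type": "sql_injection", "title": "B"},
]

# Threshold from _verify_finding: "confirmed" always counts as verified;
# "likely" counts only when confidence >= 0.8.
def is_verified(verdict: str, confidence: float) -> bool:
    return verdict == "confirmed" or (verdict == "likely" and confidence >= 0.8)

assert is_verified("likely", 0.85) and not is_verified("likely", 0.7)
```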
@@ -0,0 +1,371 @@
"""
Agent event manager
Responsible for creating, storing, and pushing events
"""

import asyncio
import json
import logging
from typing import Optional, Dict, Any, List, AsyncGenerator, Callable
from datetime import datetime, timezone
from dataclasses import dataclass
import uuid

logger = logging.getLogger(__name__)


@dataclass
class AgentEventData:
    """Agent event payload."""
    event_type: str
    phase: Optional[str] = None
    message: Optional[str] = None
    tool_name: Optional[str] = None
    tool_input: Optional[Dict[str, Any]] = None
    tool_output: Optional[Dict[str, Any]] = None
    tool_duration_ms: Optional[int] = None
    finding_id: Optional[str] = None
    tokens_used: int = 0
    metadata: Optional[Dict[str, Any]] = None

    def to_dict(self) -> Dict[str, Any]:
        return {
            "event_type": self.event_type,
            "phase": self.phase,
            "message": self.message,
            "tool_name": self.tool_name,
            "tool_input": self.tool_input,
            "tool_output": self.tool_output,
            "tool_duration_ms": self.tool_duration_ms,
            "finding_id": self.finding_id,
            "tokens_used": self.tokens_used,
            "metadata": self.metadata,
        }


class AgentEventEmitter:
    """
    Agent event emitter
    Used to emit events while an Agent is running
    """

    def __init__(self, task_id: str, event_manager: 'EventManager'):
        self.task_id = task_id
        self.event_manager = event_manager
        self._sequence = 0
        self._current_phase = None

    async def emit(self, event_data: AgentEventData):
        """Emit an event."""
        self._sequence += 1
        event_data.phase = event_data.phase or self._current_phase

        await self.event_manager.add_event(
            task_id=self.task_id,
            sequence=self._sequence,
            **event_data.to_dict()
        )

    async def emit_phase_start(self, phase: str, message: Optional[str] = None):
        """Emit a phase-start event."""
        self._current_phase = phase
        await self.emit(AgentEventData(
            event_type="phase_start",
            phase=phase,
            message=message or f"Starting phase {phase}",
        ))

    async def emit_phase_complete(self, phase: str, message: Optional[str] = None):
        """Emit a phase-complete event."""
        await self.emit(AgentEventData(
            event_type="phase_complete",
            phase=phase,
            message=message or f"Phase {phase} complete",
        ))

    async def emit_thinking(self, message: str, metadata: Optional[Dict] = None):
        """Emit a thinking event."""
        await self.emit(AgentEventData(
            event_type="thinking",
            message=message,
            metadata=metadata,
        ))

    async def emit_tool_call(
        self,
        tool_name: str,
        tool_input: Dict[str, Any],
        message: Optional[str] = None,
    ):
        """Emit a tool-call event."""
        await self.emit(AgentEventData(
            event_type="tool_call",
            tool_name=tool_name,
            tool_input=tool_input,
            message=message or f"Calling tool: {tool_name}",
        ))

    async def emit_tool_result(
        self,
        tool_name: str,
        tool_output: Any,
        duration_ms: int,
        message: Optional[str] = None,
    ):
        """Emit a tool-result event."""
        # Normalize the output so it is serializable
        if hasattr(tool_output, 'to_dict'):
            output_data = tool_output.to_dict()
        elif isinstance(tool_output, str):
            output_data = {"result": tool_output[:2000]}  # Truncate long output
        else:
            output_data = {"result": str(tool_output)[:2000]}

        await self.emit(AgentEventData(
            event_type="tool_result",
            tool_name=tool_name,
            tool_output=output_data,
            tool_duration_ms=duration_ms,
            message=message or f"Tool {tool_name} finished ({duration_ms}ms)",
        ))

    async def emit_finding(
        self,
        finding_id: str,
        title: str,
        severity: str,
        vulnerability_type: str,
        is_verified: bool = False,
    ):
        """Emit a vulnerability-finding event."""
        event_type = "finding_verified" if is_verified else "finding_new"
        await self.emit(AgentEventData(
            event_type=event_type,
            finding_id=finding_id,
            message=f"{'✅ Verified' if is_verified else '🔍 New finding'}: [{severity.upper()}] {title}",
            metadata={
                "title": title,
                "severity": severity,
                "vulnerability_type": vulnerability_type,
                "is_verified": is_verified,
            },
        ))

    async def emit_info(self, message: str, metadata: Optional[Dict] = None):
        """Emit an info event."""
        await self.emit(AgentEventData(
            event_type="info",
            message=message,
            metadata=metadata,
        ))

    async def emit_warning(self, message: str, metadata: Optional[Dict] = None):
        """Emit a warning event."""
        await self.emit(AgentEventData(
            event_type="warning",
            message=message,
            metadata=metadata,
        ))

    async def emit_error(self, message: str, metadata: Optional[Dict] = None):
        """Emit an error event."""
        await self.emit(AgentEventData(
            event_type="error",
            message=message,
            metadata=metadata,
        ))

    async def emit_progress(
        self,
        current: int,
        total: int,
        message: Optional[str] = None,
    ):
        """Emit a progress event."""
        percentage = (current / total * 100) if total > 0 else 0
        await self.emit(AgentEventData(
            event_type="progress",
            message=message or f"Progress: {current}/{total} ({percentage:.1f}%)",
            metadata={
                "current": current,
                "total": total,
                "percentage": percentage,
            },
        ))


class EventManager:
    """
    Event manager
    Responsible for storing and retrieving events
    """

    def __init__(self, db_session_factory=None):
        self.db_session_factory = db_session_factory
        self._event_queues: Dict[str, asyncio.Queue] = {}
        self._event_callbacks: Dict[str, List[Callable]] = {}

    async def add_event(
        self,
        task_id: str,
        event_type: str,
        sequence: int = 0,
        phase: Optional[str] = None,
        message: Optional[str] = None,
        tool_name: Optional[str] = None,
        tool_input: Optional[Dict] = None,
        tool_output: Optional[Dict] = None,
        tool_duration_ms: Optional[int] = None,
        finding_id: Optional[str] = None,
        tokens_used: int = 0,
        metadata: Optional[Dict] = None,
    ):
        """Add an event."""
        event_id = str(uuid.uuid4())
        timestamp = datetime.now(timezone.utc)

        event_data = {
            "id": event_id,
            "task_id": task_id,
            "event_type": event_type,
            "sequence": sequence,
            "phase": phase,
            "message": message,
            "tool_name": tool_name,
            "tool_input": tool_input,
            "tool_output": tool_output,
            "tool_duration_ms": tool_duration_ms,
            "finding_id": finding_id,
            "tokens_used": tokens_used,
            "metadata": metadata,
            "timestamp": timestamp.isoformat(),
        }

        # Persist to the database
        if self.db_session_factory:
            try:
                await self._save_event_to_db(event_data)
            except Exception as e:
                logger.error(f"Failed to save event to database: {e}")

        # Push to the queue
        if task_id in self._event_queues:
            await self._event_queues[task_id].put(event_data)

        # Invoke callbacks
        if task_id in self._event_callbacks:
            for callback in self._event_callbacks[task_id]:
                try:
                    if asyncio.iscoroutinefunction(callback):
                        await callback(event_data)
                    else:
                        callback(event_data)
                except Exception as e:
                    logger.error(f"Event callback error: {e}")

        return event_id

    async def _save_event_to_db(self, event_data: Dict):
        """Persist an event to the database."""
        from app.models.agent_task import AgentEvent

        async with self.db_session_factory() as db:
            event = AgentEvent(
                id=event_data["id"],
                task_id=event_data["task_id"],
                event_type=event_data["event_type"],
                sequence=event_data["sequence"],
                phase=event_data["phase"],
                message=event_data["message"],
                tool_name=event_data["tool_name"],
                tool_input=event_data["tool_input"],
                tool_output=event_data["tool_output"],
                tool_duration_ms=event_data["tool_duration_ms"],
                finding_id=event_data["finding_id"],
                tokens_used=event_data["tokens_used"],
                event_metadata=event_data["metadata"],
            )
            db.add(event)
            await db.commit()

    def create_queue(self, task_id: str) -> asyncio.Queue:
        """Create an event queue."""
        if task_id not in self._event_queues:
            self._event_queues[task_id] = asyncio.Queue()
        return self._event_queues[task_id]

    def remove_queue(self, task_id: str):
        """Remove an event queue."""
        if task_id in self._event_queues:
            del self._event_queues[task_id]

    def add_callback(self, task_id: str, callback: Callable):
        """Register an event callback."""
        if task_id not in self._event_callbacks:
            self._event_callbacks[task_id] = []
        self._event_callbacks[task_id].append(callback)

    def remove_callback(self, task_id: str, callback: Callable):
        """Remove an event callback."""
        if task_id in self._event_callbacks:
            self._event_callbacks[task_id].remove(callback)

    async def get_events(
        self,
        task_id: str,
        after_sequence: int = 0,
        limit: int = 100,
    ) -> List[Dict]:
        """Fetch a list of events."""
        if not self.db_session_factory:
            return []

        from sqlalchemy.future import select
        from app.models.agent_task import AgentEvent

        async with self.db_session_factory() as db:
            result = await db.execute(
                select(AgentEvent)
                .where(AgentEvent.task_id == task_id)
                .where(AgentEvent.sequence > after_sequence)
                .order_by(AgentEvent.sequence)
                .limit(limit)
            )
            events = result.scalars().all()
            return [event.to_sse_dict() for event in events]

    async def stream_events(
        self,
        task_id: str,
        after_sequence: int = 0,
    ) -> AsyncGenerator[Dict, None]:
        """Stream events."""
        queue = self.create_queue(task_id)

        # Send historical events first
        history = await self.get_events(task_id, after_sequence)
        for event in history:
            yield event

        # Then push new events in real time
        try:
            while True:
                try:
                    event = await asyncio.wait_for(queue.get(), timeout=30)
                    yield event

                    # Check for terminal events
                    if event.get("event_type") in ["task_complete", "task_error", "task_cancel"]:
                        break

                except asyncio.TimeoutError:
                    # Send a heartbeat
                    yield {"event_type": "heartbeat", "timestamp": datetime.now(timezone.utc).isoformat()}

        finally:
            self.remove_queue(task_id)

    def create_emitter(self, task_id: str) -> AgentEventEmitter:
        """Create an event emitter."""
        return AgentEventEmitter(task_id, self)
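A minimal sketch of how `stream_events` could back a server-sent-events endpoint, assuming a FastAPI app and a shared `EventManager` instance; the route path and the module-level `event_manager` wiring are illustrative assumptions, not part of this commit:

```python
# Illustrative only: the route path and module-level event_manager are assumptions.
import json
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()
event_manager = EventManager()  # in practice, constructed with a db_session_factory

@app.get("/api/agent-tasks/{task_id}/events")
async def stream_task_events(task_id: str, after_sequence: int = 0):
    async def sse():
        # Each event dict becomes one SSE "data:" frame
        async for event in event_manager.stream_events(task_id, after_sequence):
            yield f"data: {json.dumps(event, ensure_ascii=False)}\n\n"
    return StreamingResponse(sse(), media_type="text/event-stream")
```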
@@ -0,0 +1,28 @@
"""
LangGraph workflow module
Builds the hybrid Agent audit flow as a state graph
"""

from .audit_graph import AuditState, create_audit_graph, create_audit_graph_with_human
from .nodes import ReconNode, AnalysisNode, VerificationNode, ReportNode, HumanReviewNode
from .runner import AgentRunner, run_agent_task, LLMService

__all__ = [
    # State and graph
    "AuditState",
    "create_audit_graph",
    "create_audit_graph_with_human",

    # Nodes
    "ReconNode",
    "AnalysisNode",
    "VerificationNode",
    "ReportNode",
    "HumanReviewNode",

    # Runner
    "AgentRunner",
    "run_agent_task",
    "LLMService",
]
@@ -0,0 +1,455 @@
"""
DeepAudit audit workflow graph
Uses LangGraph to build a state-machine style Agent collaboration flow
"""

from typing import TypedDict, Annotated, List, Dict, Any, Optional, Literal
from datetime import datetime
import operator
import logging

from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import ToolNode

logger = logging.getLogger(__name__)


# ============ State definitions ============

class Finding(TypedDict):
    """A vulnerability finding."""
    id: str
    vulnerability_type: str
    severity: str
    title: str
    description: str
    file_path: Optional[str]
    line_start: Optional[int]
    code_snippet: Optional[str]
    is_verified: bool
    confidence: float
    source: str


class AuditState(TypedDict):
    """
    Audit state
    Passed along and updated throughout the workflow
    """
    # Input
    project_root: str
    project_info: Dict[str, Any]
    config: Dict[str, Any]
    task_id: str

    # Recon phase output
    tech_stack: Dict[str, Any]
    entry_points: List[Dict[str, Any]]
    high_risk_areas: List[str]
    dependencies: Dict[str, Any]

    # Analysis phase output
    findings: Annotated[List[Finding], operator.add]  # merged across iterations via add

    # Verification phase output
    verified_findings: List[Finding]
    false_positives: List[str]

    # Control flow
    current_phase: str
    iteration: int
    max_iterations: int
    should_continue_analysis: bool

    # Messages and events
    messages: Annotated[List[Dict], operator.add]
    events: Annotated[List[Dict], operator.add]

    # Final output
    summary: Optional[Dict[str, Any]]
    security_score: Optional[int]
    error: Optional[str]


# ============ Routing functions ============

def route_after_recon(state: AuditState) -> Literal["analysis", "end"]:
    """Routing decision after Recon."""
    # If no entry points or high-risk areas were found, end immediately
    if not state.get("entry_points") and not state.get("high_risk_areas"):
        return "end"
    return "analysis"


def route_after_analysis(state: AuditState) -> Literal["verification", "analysis", "report"]:
    """Routing decision after Analysis."""
    findings = state.get("findings", [])
    iteration = state.get("iteration", 0)
    max_iterations = state.get("max_iterations", 3)
    should_continue = state.get("should_continue_analysis", False)

    # No findings: go straight to the report
    if not findings:
        return "report"

    # Keep analyzing if requested and the iteration cap has not been reached
    if should_continue and iteration < max_iterations:
        return "analysis"

    # There are findings to verify
    return "verification"


def route_after_verification(state: AuditState) -> Literal["analysis", "report"]:
    """Routing decision after Verification."""
    # If verification uncovered false positives, re-analysis may be needed
    false_positives = state.get("false_positives", [])
    iteration = state.get("iteration", 0)
    max_iterations = state.get("max_iterations", 3)

    # If the false-positive rate is too high and iterations remain, go back to analysis
    if len(false_positives) > len(state.get("verified_findings", [])) and iteration < max_iterations:
        return "analysis"

    return "report"


# ============ Graph construction ============

def create_audit_graph(
    recon_node,
    analysis_node,
    verification_node,
    report_node,
    checkpointer: Optional[MemorySaver] = None,
) -> StateGraph:
    """
    Create the audit workflow graph

    Args:
        recon_node: information-gathering node
        analysis_node: vulnerability-analysis node
        verification_node: vulnerability-verification node
        report_node: report-generation node
        checkpointer: checkpoint store (used for state persistence)

    Returns:
        The compiled StateGraph

    Workflow structure:

        START
          │
          ▼
        ┌──────┐
        │Recon │  information gathering
        └──┬───┘
           │
           ▼
        ┌──────────┐
        │ Analysis │◄─────┐  vulnerability analysis (may loop)
        └────┬─────┘      │
             │            │
             ▼            │
        ┌────────────┐    │
        │Verification│────┘  vulnerability verification (may backtrack)
        └─────┬──────┘
              │
              ▼
        ┌──────────┐
        │  Report  │  report generation
        └────┬─────┘
             │
             ▼
            END
    """

    # Create the state graph
    workflow = StateGraph(AuditState)

    # Add nodes
    workflow.add_node("recon", recon_node)
    workflow.add_node("analysis", analysis_node)
    workflow.add_node("verification", verification_node)
    workflow.add_node("report", report_node)

    # Set the entry point
    workflow.set_entry_point("recon")

    # Add conditional edges
    workflow.add_conditional_edges(
        "recon",
        route_after_recon,
        {
            "analysis": "analysis",
            "end": END,
        }
    )

    workflow.add_conditional_edges(
        "analysis",
        route_after_analysis,
        {
            "verification": "verification",
            "analysis": "analysis",  # loop
            "report": "report",
        }
    )

    workflow.add_conditional_edges(
        "verification",
        route_after_verification,
        {
            "analysis": "analysis",  # backtrack
            "report": "report",
        }
    )

    # Report -> END
    workflow.add_edge("report", END)

    # Compile the graph
    if checkpointer:
        return workflow.compile(checkpointer=checkpointer)
    else:
        return workflow.compile()


# ============ Audit graph with human-in-the-loop review ============

def create_audit_graph_with_human(
    recon_node,
    analysis_node,
    verification_node,
    report_node,
    human_review_node,
    checkpointer: Optional[MemorySaver] = None,
) -> StateGraph:
    """
    Create the audit workflow graph with human-in-the-loop review

    Adds a human review node after the verification phase

    Workflow structure:

        START
          │
          ▼
        ┌──────┐
        │Recon │
        └──┬───┘
           │
           ▼
        ┌──────────┐
        │ Analysis │◄─────┐
        └────┬─────┘      │
             │            │
             ▼            │
        ┌────────────┐    │
        │Verification│────┘
        └─────┬──────┘
              │
              ▼
        ┌─────────────┐
        │Human Review │ ← manual review (can be skipped)
        └──────┬──────┘
               │
               ▼
        ┌──────────┐
        │  Report  │
        └────┬─────┘
             │
             ▼
            END
    """

    workflow = StateGraph(AuditState)

    # Add nodes
    workflow.add_node("recon", recon_node)
    workflow.add_node("analysis", analysis_node)
    workflow.add_node("verification", verification_node)
    workflow.add_node("human_review", human_review_node)
    workflow.add_node("report", report_node)

    workflow.set_entry_point("recon")

    workflow.add_conditional_edges(
        "recon",
        route_after_recon,
        {"analysis": "analysis", "end": END}
    )

    workflow.add_conditional_edges(
        "analysis",
        route_after_analysis,
        {
            "verification": "verification",
            "analysis": "analysis",
            "report": "report",
        }
    )

    # Verification -> Human Review
    workflow.add_edge("verification", "human_review")

    # Routing after human review
    def route_after_human(state: AuditState) -> Literal["analysis", "report"]:
        # The reviewer can decide to re-analyze or continue
        if state.get("should_continue_analysis"):
            return "analysis"
        return "report"

    workflow.add_conditional_edges(
        "human_review",
        route_after_human,
        {"analysis": "analysis", "report": "report"}
    )

    workflow.add_edge("report", END)

    if checkpointer:
        return workflow.compile(checkpointer=checkpointer, interrupt_before=["human_review"])
    else:
        return workflow.compile()


# ============ Runner ============

class AuditGraphRunner:
    """
    Audit graph runner
    Wraps execution of the LangGraph workflow
    """

    def __init__(
        self,
        graph: StateGraph,
        event_emitter=None,
    ):
        self.graph = graph
        self.event_emitter = event_emitter

    async def run(
        self,
        project_root: str,
        project_info: Dict[str, Any],
        config: Dict[str, Any],
        task_id: str,
    ) -> Dict[str, Any]:
        """
        Execute the audit workflow

        Args:
            project_root: project root directory
            project_info: project information
            config: audit configuration
            task_id: task ID

        Returns:
            The final state
        """
        # Initial state
        initial_state: AuditState = {
            "project_root": project_root,
            "project_info": project_info,
            "config": config,
            "task_id": task_id,
            "tech_stack": {},
            "entry_points": [],
            "high_risk_areas": [],
            "dependencies": {},
            "findings": [],
            "verified_findings": [],
            "false_positives": [],
            "current_phase": "start",
            "iteration": 0,
            "max_iterations": config.get("max_iterations", 3),
            "should_continue_analysis": False,
            "messages": [],
            "events": [],
            "summary": None,
            "security_score": None,
            "error": None,
        }

        # Run configuration
        run_config = {
            "configurable": {
                "thread_id": task_id,
            }
        }

        # Execute the graph
        try:
            # Streamed execution
            async for event in self.graph.astream(initial_state, config=run_config):
                # Emit events
                if self.event_emitter:
                    for node_name, node_state in event.items():
                        await self.event_emitter.emit_info(
                            f"Node {node_name} finished"
                        )

                        # Emit finding events
                        if node_name == "analysis" and node_state.get("findings"):
                            new_findings = node_state["findings"]
                            await self.event_emitter.emit_info(
                                f"Found {len(new_findings)} potential vulnerabilities"
                            )

            # Fetch the final state
            final_state = self.graph.get_state(run_config)
            return final_state.values

        except Exception as e:
            logger.error(f"Graph execution failed: {e}", exc_info=True)
            raise

    async def run_with_human_review(
        self,
        initial_state: AuditState,
        human_feedback_callback,
    ) -> Dict[str, Any]:
        """
        Execute with human-in-the-loop review

        Args:
            initial_state: initial state
            human_feedback_callback: callback used to collect human feedback

        Returns:
            The final state
        """
        run_config = {
            "configurable": {
                "thread_id": initial_state["task_id"],
            }
        }

        # Run until the human-review node
        async for event in self.graph.astream(initial_state, config=run_config):
            pass

        # Fetch the current state
        current_state = self.graph.get_state(run_config)

        # If paused at the human-review node
        if current_state.next == ("human_review",):
            # Request human feedback
            human_decision = await human_feedback_callback(current_state.values)

            # Update the state and continue
            updated_state = {
                **current_state.values,
                "should_continue_analysis": human_decision.get("continue_analysis", False),
            }

            # Resume execution
            async for event in self.graph.astream(updated_state, config=run_config):
                pass

        # Return the final state
        return self.graph.get_state(run_config).values
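A minimal sketch of wiring the graph together and running it, assuming the node classes from nodes.py have already been given their agents; the agent objects, the emitter, and ReportNode's constructor (assumed to match BaseNode) are placeholders, not part of this commit:

```python
# Illustrative only: recon_agent, analysis_agent, verification_agent, report_agent,
# and emitter are placeholders for objects built elsewhere in the application.
from langgraph.checkpoint.memory import MemorySaver

async def run_demo_audit(recon_agent, analysis_agent, verification_agent, report_agent, emitter):
    graph = create_audit_graph(
        recon_node=ReconNode(agent=recon_agent, event_emitter=emitter),
        analysis_node=AnalysisNode(agent=analysis_agent, event_emitter=emitter),
        verification_node=VerificationNode(agent=verification_agent, event_emitter=emitter),
        report_node=ReportNode(agent=report_agent, event_emitter=emitter),
        checkpointer=MemorySaver(),  # a checkpointer lets get_state() read back the final state
    )
    runner = AuditGraphRunner(graph, event_emitter=emitter)
    final_state = await runner.run(
        project_root="/path/to/project",
        project_info={"name": "demo"},
        config={"max_iterations": 3, "verification_level": "sandbox"},
        task_id="task-123",
    )
    return final_state.get("verified_findings", [])
```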
@ -0,0 +1,360 @@
|
||||||
|
"""
|
||||||
|
LangGraph 节点实现
|
||||||
|
每个节点封装一个 Agent 的执行逻辑
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Dict, Any, List, Optional
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 延迟导入避免循环依赖
|
||||||
|
def get_audit_state_type():
|
||||||
|
from .audit_graph import AuditState
|
||||||
|
return AuditState
|
||||||
|
|
||||||
|
|
||||||
|
class BaseNode:
|
||||||
|
"""节点基类"""
|
||||||
|
|
||||||
|
def __init__(self, agent=None, event_emitter=None):
|
||||||
|
self.agent = agent
|
||||||
|
self.event_emitter = event_emitter
|
||||||
|
|
||||||
|
async def emit_event(self, event_type: str, message: str, **kwargs):
|
||||||
|
"""发射事件"""
|
||||||
|
if self.event_emitter:
|
||||||
|
try:
|
||||||
|
await self.event_emitter.emit_info(message)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to emit event: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
class ReconNode(BaseNode):
|
||||||
|
"""
|
||||||
|
信息收集节点
|
||||||
|
|
||||||
|
输入: project_root, project_info, config
|
||||||
|
输出: tech_stack, entry_points, high_risk_areas, dependencies
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""执行信息收集"""
|
||||||
|
await self.emit_event("phase_start", "🔍 开始信息收集阶段")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 调用 Recon Agent
|
||||||
|
result = await self.agent.run({
|
||||||
|
"project_info": state["project_info"],
|
||||||
|
"config": state["config"],
|
||||||
|
})
|
||||||
|
|
||||||
|
if result.success and result.data:
|
||||||
|
data = result.data
|
||||||
|
|
||||||
|
await self.emit_event(
|
||||||
|
"phase_complete",
|
||||||
|
f"✅ 信息收集完成: 发现 {len(data.get('entry_points', []))} 个入口点"
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"tech_stack": data.get("tech_stack", {}),
|
||||||
|
"entry_points": data.get("entry_points", []),
|
||||||
|
"high_risk_areas": data.get("high_risk_areas", []),
|
||||||
|
"dependencies": data.get("dependencies", {}),
|
||||||
|
"current_phase": "recon_complete",
|
||||||
|
"findings": data.get("initial_findings", []), # 初步发现
|
||||||
|
"events": [{
|
||||||
|
"type": "recon_complete",
|
||||||
|
"data": {
|
||||||
|
"entry_points_count": len(data.get("entry_points", [])),
|
||||||
|
"high_risk_areas_count": len(data.get("high_risk_areas", [])),
|
||||||
|
}
|
||||||
|
}],
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"error": result.error or "Recon failed",
|
||||||
|
"current_phase": "error",
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Recon node failed: {e}", exc_info=True)
|
||||||
|
return {
|
||||||
|
"error": str(e),
|
||||||
|
"current_phase": "error",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class AnalysisNode(BaseNode):
|
||||||
|
"""
|
||||||
|
漏洞分析节点
|
||||||
|
|
||||||
|
输入: tech_stack, entry_points, high_risk_areas, previous findings
|
||||||
|
输出: findings (累加), should_continue_analysis
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""执行漏洞分析"""
|
||||||
|
iteration = state.get("iteration", 0) + 1
|
||||||
|
|
||||||
|
await self.emit_event(
|
||||||
|
"phase_start",
|
||||||
|
f"🔬 开始漏洞分析阶段 (迭代 {iteration})"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 构建分析输入
|
||||||
|
analysis_input = {
|
||||||
|
"phase_name": "analysis",
|
||||||
|
"project_info": state["project_info"],
|
||||||
|
"config": state["config"],
|
||||||
|
"plan": {
|
||||||
|
"high_risk_areas": state.get("high_risk_areas", []),
|
||||||
|
},
|
||||||
|
"previous_results": {
|
||||||
|
"recon": {
|
||||||
|
"data": {
|
||||||
|
"tech_stack": state.get("tech_stack", {}),
|
||||||
|
"entry_points": state.get("entry_points", []),
|
||||||
|
"high_risk_areas": state.get("high_risk_areas", []),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# 调用 Analysis Agent
|
||||||
|
result = await self.agent.run(analysis_input)
|
||||||
|
|
||||||
|
if result.success and result.data:
|
||||||
|
new_findings = result.data.get("findings", [])
|
||||||
|
|
||||||
|
# 判断是否需要继续分析
|
||||||
|
# 如果这一轮发现了很多问题,可能还有更多
|
||||||
|
should_continue = (
|
||||||
|
len(new_findings) >= 5 and
|
||||||
|
iteration < state.get("max_iterations", 3)
|
||||||
|
)
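                # Worked example of this heuristic: with max_iterations=3, finding 6 issues in
                # iteration 1 triggers another analysis pass; finding 4 issues, or reaching
                # iteration 3, ends the analysis loop.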
|
||||||
|
|
||||||
|
await self.emit_event(
|
||||||
|
"phase_complete",
|
||||||
|
f"✅ 分析迭代 {iteration} 完成: 发现 {len(new_findings)} 个潜在漏洞"
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"findings": new_findings, # 会自动累加
|
||||||
|
"iteration": iteration,
|
||||||
|
"should_continue_analysis": should_continue,
|
||||||
|
"current_phase": "analysis_complete",
|
||||||
|
"events": [{
|
||||||
|
"type": "analysis_iteration",
|
||||||
|
"data": {
|
||||||
|
"iteration": iteration,
|
||||||
|
"findings_count": len(new_findings),
|
||||||
|
}
|
||||||
|
}],
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"iteration": iteration,
|
||||||
|
"should_continue_analysis": False,
|
||||||
|
"current_phase": "analysis_complete",
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Analysis node failed: {e}", exc_info=True)
|
||||||
|
return {
|
||||||
|
"error": str(e),
|
||||||
|
"should_continue_analysis": False,
|
||||||
|
"current_phase": "error",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class VerificationNode(BaseNode):
|
||||||
|
"""
|
||||||
|
漏洞验证节点
|
||||||
|
|
||||||
|
输入: findings
|
||||||
|
输出: verified_findings, false_positives
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""执行漏洞验证"""
|
||||||
|
findings = state.get("findings", [])
|
||||||
|
|
||||||
|
if not findings:
|
||||||
|
return {
|
||||||
|
"verified_findings": [],
|
||||||
|
"false_positives": [],
|
||||||
|
"current_phase": "verification_complete",
|
||||||
|
}
|
||||||
|
|
||||||
|
await self.emit_event(
|
||||||
|
"phase_start",
|
||||||
|
f"🔐 开始漏洞验证阶段 ({len(findings)} 个待验证)"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 构建验证输入
|
||||||
|
verification_input = {
|
||||||
|
"previous_results": {
|
||||||
|
"analysis": {
|
||||||
|
"data": {
|
||||||
|
"findings": findings,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"config": state["config"],
|
||||||
|
}
|
||||||
|
|
||||||
|
# 调用 Verification Agent
|
||||||
|
result = await self.agent.run(verification_input)
|
||||||
|
|
||||||
|
if result.success and result.data:
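                # The verifier returns a single findings list; keep the full dicts for
                # confirmed findings and only the ids of entries judged to be false positives.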
|
||||||
|
verified = [f for f in result.data.get("findings", []) if f.get("is_verified")]
|
||||||
|
false_pos = [f["id"] for f in result.data.get("findings", [])
|
||||||
|
if f.get("verdict") == "false_positive"]
|
||||||
|
|
||||||
|
await self.emit_event(
|
||||||
|
"phase_complete",
|
||||||
|
f"✅ 验证完成: {len(verified)} 已确认, {len(false_pos)} 误报"
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"verified_findings": verified,
|
||||||
|
"false_positives": false_pos,
|
||||||
|
"current_phase": "verification_complete",
|
||||||
|
"events": [{
|
||||||
|
"type": "verification_complete",
|
||||||
|
"data": {
|
||||||
|
"verified_count": len(verified),
|
||||||
|
"false_positive_count": len(false_pos),
|
||||||
|
}
|
||||||
|
}],
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"verified_findings": [],
|
||||||
|
"false_positives": [],
|
||||||
|
"current_phase": "verification_complete",
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Verification node failed: {e}", exc_info=True)
|
||||||
|
return {
|
||||||
|
"error": str(e),
|
||||||
|
"current_phase": "error",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ReportNode(BaseNode):
|
||||||
|
"""
|
||||||
|
报告生成节点
|
||||||
|
|
||||||
|
输入: all state
|
||||||
|
输出: summary, security_score
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""生成审计报告"""
|
||||||
|
await self.emit_event("phase_start", "📊 生成审计报告")
|
||||||
|
|
||||||
|
try:
|
||||||
|
findings = state.get("findings", [])
|
||||||
|
verified = state.get("verified_findings", [])
|
||||||
|
false_positives = state.get("false_positives", [])
|
||||||
|
|
||||||
|
# 统计漏洞分布
|
||||||
|
severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0}
|
||||||
|
type_counts = {}
|
||||||
|
|
||||||
|
for finding in findings:
|
||||||
|
sev = finding.get("severity", "medium")
|
||||||
|
severity_counts[sev] = severity_counts.get(sev, 0) + 1
|
||||||
|
|
||||||
|
vtype = finding.get("vulnerability_type", "other")
|
||||||
|
type_counts[vtype] = type_counts.get(vtype, 0) + 1
|
||||||
|
|
||||||
|
# 计算安全评分
|
||||||
|
base_score = 100
|
||||||
|
deductions = (
|
||||||
|
severity_counts["critical"] * 25 +
|
||||||
|
severity_counts["high"] * 15 +
|
||||||
|
severity_counts["medium"] * 8 +
|
||||||
|
severity_counts["low"] * 3
|
||||||
|
)
|
||||||
|
security_score = max(0, base_score - deductions)
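            # Example: 1 critical + 2 high + 3 medium findings → 100 - (25 + 30 + 24) = 21.
            # The score floors at 0, so heavily affected projects never go negative.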
|
||||||
|
|
||||||
|
# 生成摘要
|
||||||
|
summary = {
|
||||||
|
"total_findings": len(findings),
|
||||||
|
"verified_count": len(verified),
|
||||||
|
"false_positive_count": len(false_positives),
|
||||||
|
"severity_distribution": severity_counts,
|
||||||
|
"vulnerability_types": type_counts,
|
||||||
|
"tech_stack": state.get("tech_stack", {}),
|
||||||
|
"entry_points_analyzed": len(state.get("entry_points", [])),
|
||||||
|
"high_risk_areas": state.get("high_risk_areas", []),
|
||||||
|
"iterations": state.get("iteration", 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
await self.emit_event(
|
||||||
|
"phase_complete",
|
||||||
|
f"✅ 报告生成完成: 安全评分 {security_score}/100"
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"summary": summary,
|
||||||
|
"security_score": security_score,
|
||||||
|
"current_phase": "complete",
|
||||||
|
"events": [{
|
||||||
|
"type": "audit_complete",
|
||||||
|
"data": {
|
||||||
|
"security_score": security_score,
|
||||||
|
"total_findings": len(findings),
|
||||||
|
"verified_count": len(verified),
|
||||||
|
}
|
||||||
|
}],
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Report node failed: {e}", exc_info=True)
|
||||||
|
return {
|
||||||
|
"error": str(e),
|
||||||
|
"current_phase": "error",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class HumanReviewNode(BaseNode):
|
||||||
|
"""
|
||||||
|
人工审核节点
|
||||||
|
|
||||||
|
在此节点暂停,等待人工反馈
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
人工审核节点
|
||||||
|
|
||||||
|
这个节点会被 interrupt_before 暂停
|
||||||
|
用户可以:
|
||||||
|
1. 确认发现
|
||||||
|
2. 标记误报
|
||||||
|
3. 请求重新分析
|
||||||
|
"""
|
||||||
|
await self.emit_event(
|
||||||
|
"human_review",
|
||||||
|
f"⏸️ 等待人工审核 ({len(state.get('verified_findings', []))} 个待确认)"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 返回当前状态,不做修改
|
||||||
|
# 人工反馈会通过 update_state 传入
|
||||||
|
return {
|
||||||
|
"current_phase": "human_review",
|
||||||
|
"messages": [{
|
||||||
|
"role": "system",
|
||||||
|
"content": "等待人工审核",
|
||||||
|
"findings_for_review": state.get("verified_findings", []),
|
||||||
|
}],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
@ -0,0 +1,621 @@
|
||||||
|
"""
|
||||||
|
DeepAudit LangGraph Runner
|
||||||
|
基于 LangGraph 的 Agent 审计执行器
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Dict, List, Optional, Any, AsyncGenerator
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from langgraph.graph import StateGraph, END
|
||||||
|
from langgraph.checkpoint.memory import MemorySaver
|
||||||
|
|
||||||
|
from app.models.agent_task import (
|
||||||
|
AgentTask, AgentEvent, AgentFinding,
|
||||||
|
AgentTaskStatus, AgentTaskPhase, AgentEventType,
|
||||||
|
VulnerabilitySeverity, VulnerabilityType, FindingStatus,
|
||||||
|
)
|
||||||
|
from app.services.agent.event_manager import EventManager, AgentEventEmitter
|
||||||
|
from app.services.agent.tools import (
|
||||||
|
RAGQueryTool, SecurityCodeSearchTool, FunctionContextTool,
|
||||||
|
PatternMatchTool, CodeAnalysisTool, DataFlowAnalysisTool, VulnerabilityValidationTool,
|
||||||
|
FileReadTool, FileSearchTool, ListFilesTool,
|
||||||
|
SandboxTool, SandboxHttpTool, VulnerabilityVerifyTool, SandboxManager,
|
||||||
|
SemgrepTool, BanditTool, GitleaksTool, NpmAuditTool, SafetyTool,
|
||||||
|
TruffleHogTool, OSVScannerTool,
|
||||||
|
)
|
||||||
|
from app.services.rag import CodeIndexer, CodeRetriever, EmbeddingService
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
|
from .audit_graph import AuditState, create_audit_graph
|
||||||
|
from .nodes import ReconNode, AnalysisNode, VerificationNode, ReportNode
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class LLMService:
|
||||||
|
"""LLM 服务封装"""
|
||||||
|
|
||||||
|
def __init__(self, model: Optional[str] = None, api_key: Optional[str] = None):
|
||||||
|
self.model = model or settings.DEFAULT_LLM_MODEL
|
||||||
|
self.api_key = api_key or settings.LLM_API_KEY
|
||||||
|
|
||||||
|
async def chat_completion_raw(
|
||||||
|
self,
|
||||||
|
messages: List[Dict[str, str]],
|
||||||
|
temperature: float = 0.1,
|
||||||
|
max_tokens: int = 4096,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""调用 LLM 生成响应"""
|
||||||
|
try:
|
||||||
|
import litellm
|
||||||
|
|
||||||
|
response = await litellm.acompletion(
|
||||||
|
model=self.model,
|
||||||
|
messages=messages,
|
||||||
|
temperature=temperature,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
api_key=self.api_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"content": response.choices[0].message.content,
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": response.usage.prompt_tokens,
|
||||||
|
"completion_tokens": response.usage.completion_tokens,
|
||||||
|
"total_tokens": response.usage.total_tokens,
|
||||||
|
} if response.usage else {},
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"LLM call failed: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
class AgentRunner:
|
||||||
|
"""
|
||||||
|
DeepAudit LangGraph Agent Runner
|
||||||
|
|
||||||
|
基于 LangGraph 状态图的审计执行器
|
||||||
|
|
||||||
|
工作流:
|
||||||
|
START → Recon → Analysis ⟲ → Verification → Report → END
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
task: AgentTask,
|
||||||
|
project_root: str,
|
||||||
|
):
|
||||||
|
self.db = db
|
||||||
|
self.task = task
|
||||||
|
self.project_root = project_root
|
||||||
|
|
||||||
|
# 事件管理
|
||||||
|
self.event_manager = EventManager(db, task.id)
|
||||||
|
self.event_emitter = AgentEventEmitter(self.event_manager)
|
||||||
|
|
||||||
|
# LLM 服务
|
||||||
|
self.llm_service = LLMService()
|
||||||
|
|
||||||
|
# 工具集
|
||||||
|
self.tools: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
# RAG 组件
|
||||||
|
self.retriever: Optional[CodeRetriever] = None
|
||||||
|
self.indexer: Optional[CodeIndexer] = None
|
||||||
|
|
||||||
|
# 沙箱
|
||||||
|
self.sandbox_manager: Optional[SandboxManager] = None
|
||||||
|
|
||||||
|
# LangGraph
|
||||||
|
self.graph: Optional[StateGraph] = None
|
||||||
|
self.checkpointer = MemorySaver()
|
||||||
|
|
||||||
|
# 状态
|
||||||
|
self._cancelled = False
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
"""初始化 Runner"""
|
||||||
|
await self.event_emitter.emit_info("🚀 正在初始化 DeepAudit LangGraph Agent...")
|
||||||
|
|
||||||
|
# 1. 初始化 RAG 系统
|
||||||
|
await self._initialize_rag()
|
||||||
|
|
||||||
|
# 2. 初始化工具
|
||||||
|
await self._initialize_tools()
|
||||||
|
|
||||||
|
# 3. 构建 LangGraph
|
||||||
|
await self._build_graph()
|
||||||
|
|
||||||
|
await self.event_emitter.emit_info("✅ LangGraph 系统初始化完成")
|
||||||
|
|
||||||
|
async def _initialize_rag(self):
|
||||||
|
"""初始化 RAG 系统"""
|
||||||
|
await self.event_emitter.emit_info("📚 初始化 RAG 代码检索系统...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
embedding_service = EmbeddingService(
|
||||||
|
provider=settings.EMBEDDING_PROVIDER,
|
||||||
|
model=settings.EMBEDDING_MODEL,
|
||||||
|
api_key=settings.LLM_API_KEY,
|
||||||
|
base_url=settings.LLM_BASE_URL,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.indexer = CodeIndexer(
|
||||||
|
embedding_service=embedding_service,
|
||||||
|
vector_db_path=settings.VECTOR_DB_PATH,
|
||||||
|
collection_name=f"project_{self.task.project_id}",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.retriever = CodeRetriever(
|
||||||
|
embedding_service=embedding_service,
|
||||||
|
vector_db_path=settings.VECTOR_DB_PATH,
|
||||||
|
collection_name=f"project_{self.task.project_id}",
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"RAG initialization failed: {e}")
|
||||||
|
await self.event_emitter.emit_warning(f"RAG 系统初始化失败: {e}")
|
||||||
|
|
||||||
|
async def _initialize_tools(self):
|
||||||
|
"""初始化工具集"""
|
||||||
|
await self.event_emitter.emit_info("🔧 初始化 Agent 工具集...")
|
||||||
|
|
||||||
|
# 文件工具
|
||||||
|
self.tools["read_file"] = FileReadTool(self.project_root)
|
||||||
|
self.tools["search_code"] = FileSearchTool(self.project_root)
|
||||||
|
self.tools["list_files"] = ListFilesTool(self.project_root)
|
||||||
|
|
||||||
|
# RAG 工具
|
||||||
|
if self.retriever:
|
||||||
|
self.tools["rag_query"] = RAGQueryTool(self.retriever)
|
||||||
|
self.tools["security_search"] = SecurityCodeSearchTool(self.retriever)
|
||||||
|
self.tools["function_context"] = FunctionContextTool(self.retriever)
|
||||||
|
|
||||||
|
# 分析工具
|
||||||
|
self.tools["pattern_match"] = PatternMatchTool(self.project_root)
|
||||||
|
self.tools["code_analysis"] = CodeAnalysisTool(self.llm_service)
|
||||||
|
self.tools["dataflow_analysis"] = DataFlowAnalysisTool(self.llm_service)
|
||||||
|
self.tools["vulnerability_validation"] = VulnerabilityValidationTool(self.llm_service)
|
||||||
|
|
||||||
|
# 外部安全工具
|
||||||
|
self.tools["semgrep_scan"] = SemgrepTool(self.project_root)
|
||||||
|
self.tools["bandit_scan"] = BanditTool(self.project_root)
|
||||||
|
self.tools["gitleaks_scan"] = GitleaksTool(self.project_root)
|
||||||
|
self.tools["trufflehog_scan"] = TruffleHogTool(self.project_root)
|
||||||
|
self.tools["npm_audit"] = NpmAuditTool(self.project_root)
|
||||||
|
self.tools["safety_scan"] = SafetyTool(self.project_root)
|
||||||
|
self.tools["osv_scan"] = OSVScannerTool(self.project_root)
|
||||||
|
|
||||||
|
# 沙箱工具
|
||||||
|
try:
|
||||||
|
self.sandbox_manager = SandboxManager(
|
||||||
|
image=settings.SANDBOX_IMAGE,
|
||||||
|
memory_limit=settings.SANDBOX_MEMORY_LIMIT,
|
||||||
|
cpu_limit=settings.SANDBOX_CPU_LIMIT,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.tools["sandbox_exec"] = SandboxTool(self.sandbox_manager)
|
||||||
|
self.tools["sandbox_http"] = SandboxHttpTool(self.sandbox_manager)
|
||||||
|
self.tools["verify_vulnerability"] = VulnerabilityVerifyTool(self.sandbox_manager)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Sandbox initialization failed: {e}")
|
||||||
|
|
||||||
|
await self.event_emitter.emit_info(f"✅ 已加载 {len(self.tools)} 个工具")
|
||||||
|
|
||||||
|
async def _build_graph(self):
|
||||||
|
"""构建 LangGraph 审计图"""
|
||||||
|
await self.event_emitter.emit_info("📊 构建 LangGraph 审计工作流...")
|
||||||
|
|
||||||
|
# 导入 Agent
|
||||||
|
from app.services.agent.agents import ReconAgent, AnalysisAgent, VerificationAgent
|
||||||
|
|
||||||
|
# 创建 Agent 实例
|
||||||
|
recon_agent = ReconAgent(
|
||||||
|
llm_service=self.llm_service,
|
||||||
|
tools=self.tools,
|
||||||
|
event_emitter=self.event_emitter,
|
||||||
|
)
|
||||||
|
|
||||||
|
analysis_agent = AnalysisAgent(
|
||||||
|
llm_service=self.llm_service,
|
||||||
|
tools=self.tools,
|
||||||
|
event_emitter=self.event_emitter,
|
||||||
|
)
|
||||||
|
|
||||||
|
verification_agent = VerificationAgent(
|
||||||
|
llm_service=self.llm_service,
|
||||||
|
tools=self.tools,
|
||||||
|
event_emitter=self.event_emitter,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建节点
|
||||||
|
recon_node = ReconNode(recon_agent, self.event_emitter)
|
||||||
|
analysis_node = AnalysisNode(analysis_agent, self.event_emitter)
|
||||||
|
verification_node = VerificationNode(verification_agent, self.event_emitter)
|
||||||
|
report_node = ReportNode(None, self.event_emitter)
|
||||||
|
|
||||||
|
# 构建图
|
||||||
|
self.graph = create_audit_graph(
|
||||||
|
recon_node=recon_node,
|
||||||
|
analysis_node=analysis_node,
|
||||||
|
verification_node=verification_node,
|
||||||
|
report_node=report_node,
|
||||||
|
checkpointer=self.checkpointer,
|
||||||
|
)
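        # create_audit_graph is defined in audit_graph.py (not shown in this hunk). A rough,
        # assumed shape of what it wires with LangGraph primitives, matching the workflow in
        # the class docstring:
        #
        #   graph = StateGraph(AuditState)
        #   graph.add_node("recon", recon_node)
        #   graph.add_node("analysis", analysis_node)
        #   graph.add_node("verification", verification_node)
        #   graph.add_node("report", report_node)
        #   graph.set_entry_point("recon")
        #   graph.add_edge("recon", "analysis")
        #   graph.add_conditional_edges(
        #       "analysis",
        #       lambda s: "analysis" if s.get("should_continue_analysis") else "verification",
        #   )
        #   graph.add_edge("verification", "report")
        #   graph.add_edge("report", END)
        #   return graph.compile(checkpointer=checkpointer)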
|
||||||
|
|
||||||
|
await self.event_emitter.emit_info("✅ LangGraph 工作流构建完成")
|
||||||
|
|
||||||
|
async def run(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
执行 LangGraph 审计
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
最终状态
|
||||||
|
"""
|
||||||
|
import time
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 初始化
|
||||||
|
await self.initialize()
|
||||||
|
|
||||||
|
# 更新任务状态
|
||||||
|
await self._update_task_status(AgentTaskStatus.RUNNING)
|
||||||
|
|
||||||
|
# 1. 索引代码
|
||||||
|
await self._index_code()
|
||||||
|
|
||||||
|
if self._cancelled:
|
||||||
|
return {"success": False, "error": "任务已取消"}
|
||||||
|
|
||||||
|
# 2. 收集项目信息
|
||||||
|
project_info = await self._collect_project_info()
|
||||||
|
|
||||||
|
# 3. 构建初始状态
|
||||||
|
initial_state: AuditState = {
|
||||||
|
"project_root": self.project_root,
|
||||||
|
"project_info": project_info,
|
||||||
|
"config": self.task.config or {},
|
||||||
|
"task_id": self.task.id,
|
||||||
|
"tech_stack": {},
|
||||||
|
"entry_points": [],
|
||||||
|
"high_risk_areas": [],
|
||||||
|
"dependencies": {},
|
||||||
|
"findings": [],
|
||||||
|
"verified_findings": [],
|
||||||
|
"false_positives": [],
|
||||||
|
"current_phase": "start",
|
||||||
|
"iteration": 0,
|
||||||
|
"max_iterations": (self.task.config or {}).get("max_iterations", 3),
|
||||||
|
"should_continue_analysis": False,
|
||||||
|
"messages": [],
|
||||||
|
"events": [],
|
||||||
|
"summary": None,
|
||||||
|
"security_score": None,
|
||||||
|
"error": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 4. 执行 LangGraph
|
||||||
|
await self.event_emitter.emit_phase_start("langgraph", "🔄 启动 LangGraph 工作流")
|
||||||
|
|
||||||
|
run_config = {
|
||||||
|
"configurable": {
|
||||||
|
"thread_id": self.task.id,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
final_state = None
|
||||||
|
|
||||||
|
# 流式执行并发射事件
|
||||||
|
async for event in self.graph.astream(initial_state, config=run_config):
|
||||||
|
if self._cancelled:
|
||||||
|
break
|
||||||
|
|
||||||
|
# 处理每个节点的输出
|
||||||
|
for node_name, node_output in event.items():
|
||||||
|
await self._handle_node_output(node_name, node_output)
|
||||||
|
|
||||||
|
# 更新阶段
|
||||||
|
phase_map = {
|
||||||
|
"recon": AgentTaskPhase.RECONNAISSANCE,
|
||||||
|
"analysis": AgentTaskPhase.ANALYSIS,
|
||||||
|
"verification": AgentTaskPhase.VERIFICATION,
|
||||||
|
"report": AgentTaskPhase.REPORTING,
|
||||||
|
}
|
||||||
|
if node_name in phase_map:
|
||||||
|
await self._update_task_phase(phase_map[node_name])
|
||||||
|
|
||||||
|
final_state = node_output
|
||||||
|
|
||||||
|
# 5. 获取最终状态
|
||||||
|
if not final_state:
|
||||||
|
graph_state = self.graph.get_state(run_config)
|
||||||
|
final_state = graph_state.values if graph_state else {}
|
||||||
|
|
||||||
|
# 6. 保存发现
|
||||||
|
findings = final_state.get("findings", [])
|
||||||
|
await self._save_findings(findings)
|
||||||
|
|
||||||
|
# 7. 更新任务摘要
|
||||||
|
summary = final_state.get("summary", {})
|
||||||
|
security_score = final_state.get("security_score", 100)
|
||||||
|
|
||||||
|
await self._update_task_summary(
|
||||||
|
total_findings=len(findings),
|
||||||
|
verified_count=len(final_state.get("verified_findings", [])),
|
||||||
|
security_score=security_score,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 8. 完成
|
||||||
|
duration_ms = int((time.time() - start_time) * 1000)
|
||||||
|
|
||||||
|
await self._update_task_status(AgentTaskStatus.COMPLETED)
|
||||||
|
await self.event_emitter.emit_task_complete(
|
||||||
|
findings_count=len(findings),
|
||||||
|
duration_ms=duration_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"data": {
|
||||||
|
"findings": findings,
|
||||||
|
"verified_findings": final_state.get("verified_findings", []),
|
||||||
|
"summary": summary,
|
||||||
|
"security_score": security_score,
|
||||||
|
},
|
||||||
|
"duration_ms": duration_ms,
|
||||||
|
}
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
await self._update_task_status(AgentTaskStatus.CANCELLED)
|
||||||
|
return {"success": False, "error": "任务已取消"}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"LangGraph run failed: {e}", exc_info=True)
|
||||||
|
await self._update_task_status(AgentTaskStatus.FAILED, str(e))
|
||||||
|
await self.event_emitter.emit_error(str(e))
|
||||||
|
return {"success": False, "error": str(e)}
|
||||||
|
|
||||||
|
finally:
|
||||||
|
await self._cleanup()
|
||||||
|
|
||||||
|
async def _handle_node_output(self, node_name: str, output: Dict[str, Any]):
|
||||||
|
"""处理节点输出"""
|
||||||
|
# 发射节点事件
|
||||||
|
events = output.get("events", [])
|
||||||
|
for evt in events:
|
||||||
|
await self.event_emitter.emit_info(
|
||||||
|
f"[{node_name}] {evt.get('type', 'event')}: {evt.get('data', {})}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 处理新发现
|
||||||
|
if node_name == "analysis":
|
||||||
|
new_findings = output.get("findings", [])
|
||||||
|
if new_findings:
|
||||||
|
for finding in new_findings[:5]: # 限制事件数量
|
||||||
|
await self.event_emitter.emit_finding(
|
||||||
|
title=finding.get("title", "Unknown"),
|
||||||
|
severity=finding.get("severity", "medium"),
|
||||||
|
file_path=finding.get("file_path"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 处理验证结果
|
||||||
|
if node_name == "verification":
|
||||||
|
verified = output.get("verified_findings", [])
|
||||||
|
for v in verified[:5]:
|
||||||
|
await self.event_emitter.emit_info(
|
||||||
|
f"✅ 已验证: {v.get('title', 'Unknown')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 处理错误
|
||||||
|
if output.get("error"):
|
||||||
|
await self.event_emitter.emit_error(output["error"])
|
||||||
|
|
||||||
|
async def _index_code(self):
|
||||||
|
"""索引代码"""
|
||||||
|
if not self.indexer:
|
||||||
|
await self.event_emitter.emit_warning("RAG 未初始化,跳过代码索引")
|
||||||
|
return
|
||||||
|
|
||||||
|
await self._update_task_phase(AgentTaskPhase.INDEXING)
|
||||||
|
await self.event_emitter.emit_phase_start("indexing", "📝 开始代码索引")
|
||||||
|
|
||||||
|
try:
|
||||||
|
async for progress in self.indexer.index_directory(self.project_root):
|
||||||
|
if self._cancelled:
|
||||||
|
return
|
||||||
|
|
||||||
|
await self.event_emitter.emit_progress(
|
||||||
|
progress.processed / max(progress.total, 1) * 100,
|
||||||
|
f"正在索引: {progress.current_file or 'N/A'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.event_emitter.emit_phase_complete("indexing", "✅ 代码索引完成")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Code indexing failed: {e}")
|
||||||
|
await self.event_emitter.emit_warning(f"代码索引失败: {e}")
|
||||||
|
|
||||||
|
async def _collect_project_info(self) -> Dict[str, Any]:
|
||||||
|
"""收集项目信息"""
|
||||||
|
info = {
|
||||||
|
"name": self.task.project.name if self.task.project else "unknown",
|
||||||
|
"root": self.project_root,
|
||||||
|
"languages": [],
|
||||||
|
"file_count": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
exclude_dirs = {
|
||||||
|
"node_modules", "__pycache__", ".git", "venv", ".venv",
|
||||||
|
"build", "dist", "target", ".idea", ".vscode",
|
||||||
|
}
|
||||||
|
|
||||||
|
for root, dirs, files in os.walk(self.project_root):
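                # Prune excluded directories in place so os.walk never descends into them.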
|
||||||
|
dirs[:] = [d for d in dirs if d not in exclude_dirs]
|
||||||
|
info["file_count"] += len(files)
|
||||||
|
|
||||||
|
lang_map = {
|
||||||
|
".py": "Python", ".js": "JavaScript", ".ts": "TypeScript",
|
||||||
|
".java": "Java", ".go": "Go", ".php": "PHP",
|
||||||
|
".rb": "Ruby", ".rs": "Rust", ".c": "C", ".cpp": "C++",
|
||||||
|
}
|
||||||
|
|
||||||
|
for f in files:
|
||||||
|
ext = os.path.splitext(f)[1].lower()
|
||||||
|
if ext in lang_map and lang_map[ext] not in info["languages"]:
|
||||||
|
info["languages"].append(lang_map[ext])
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to collect project info: {e}")
|
||||||
|
|
||||||
|
return info
|
||||||
|
|
||||||
|
async def _save_findings(self, findings: List[Dict]):
|
||||||
|
"""保存发现到数据库"""
|
||||||
|
severity_map = {
|
||||||
|
"critical": VulnerabilitySeverity.CRITICAL,
|
||||||
|
"high": VulnerabilitySeverity.HIGH,
|
||||||
|
"medium": VulnerabilitySeverity.MEDIUM,
|
||||||
|
"low": VulnerabilitySeverity.LOW,
|
||||||
|
"info": VulnerabilitySeverity.INFO,
|
||||||
|
}
|
||||||
|
|
||||||
|
type_map = {
|
||||||
|
"sql_injection": VulnerabilityType.SQL_INJECTION,
|
||||||
|
"xss": VulnerabilityType.XSS,
|
||||||
|
"command_injection": VulnerabilityType.COMMAND_INJECTION,
|
||||||
|
"path_traversal": VulnerabilityType.PATH_TRAVERSAL,
|
||||||
|
"ssrf": VulnerabilityType.SSRF,
|
||||||
|
"hardcoded_secret": VulnerabilityType.HARDCODED_SECRET,
|
||||||
|
"deserialization": VulnerabilityType.INSECURE_DESERIALIZATION,
|
||||||
|
"weak_crypto": VulnerabilityType.WEAK_CRYPTO,
|
||||||
|
}
|
||||||
|
|
||||||
|
for finding in findings:
|
||||||
|
try:
|
||||||
|
db_finding = AgentFinding(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
task_id=self.task.id,
|
||||||
|
vulnerability_type=type_map.get(
|
||||||
|
finding.get("vulnerability_type", "other"),
|
||||||
|
VulnerabilityType.OTHER
|
||||||
|
),
|
||||||
|
severity=severity_map.get(
|
||||||
|
finding.get("severity", "medium"),
|
||||||
|
VulnerabilitySeverity.MEDIUM
|
||||||
|
),
|
||||||
|
title=finding.get("title", "Unknown"),
|
||||||
|
description=finding.get("description", ""),
|
||||||
|
file_path=finding.get("file_path"),
|
||||||
|
line_start=finding.get("line_start"),
|
||||||
|
line_end=finding.get("line_end"),
|
||||||
|
code_snippet=finding.get("code_snippet"),
|
||||||
|
source=finding.get("source"),
|
||||||
|
sink=finding.get("sink"),
|
||||||
|
suggestion=finding.get("suggestion") or finding.get("recommendation"),
|
||||||
|
is_verified=finding.get("is_verified", False),
|
||||||
|
confidence=finding.get("confidence", 0.5),
|
||||||
|
poc=finding.get("poc"),
|
||||||
|
status=FindingStatus.VERIFIED if finding.get("is_verified") else FindingStatus.OPEN,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.db.add(db_finding)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to save finding: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self.db.commit()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to commit findings: {e}")
|
||||||
|
await self.db.rollback()
|
||||||
|
|
||||||
|
async def _update_task_status(
|
||||||
|
self,
|
||||||
|
status: AgentTaskStatus,
|
||||||
|
error: Optional[str] = None
|
||||||
|
):
|
||||||
|
"""更新任务状态"""
|
||||||
|
self.task.status = status
|
||||||
|
|
||||||
|
if status == AgentTaskStatus.RUNNING:
|
||||||
|
self.task.started_at = datetime.now(timezone.utc)
|
||||||
|
elif status in [AgentTaskStatus.COMPLETED, AgentTaskStatus.FAILED, AgentTaskStatus.CANCELLED]:
|
||||||
|
self.task.finished_at = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
if error:
|
||||||
|
self.task.error_message = error
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self.db.commit()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to update task status: {e}")
|
||||||
|
|
||||||
|
async def _update_task_phase(self, phase: AgentTaskPhase):
|
||||||
|
"""更新任务阶段"""
|
||||||
|
self.task.current_phase = phase
|
||||||
|
try:
|
||||||
|
await self.db.commit()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to update task phase: {e}")
|
||||||
|
|
||||||
|
async def _update_task_summary(
|
||||||
|
self,
|
||||||
|
total_findings: int,
|
||||||
|
verified_count: int,
|
||||||
|
security_score: int,
|
||||||
|
):
|
||||||
|
"""更新任务摘要"""
|
||||||
|
self.task.total_findings = total_findings
|
||||||
|
self.task.verified_findings = verified_count
|
||||||
|
self.task.security_score = security_score
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self.db.commit()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to update task summary: {e}")
|
||||||
|
|
||||||
|
async def _cleanup(self):
|
||||||
|
"""清理资源"""
|
||||||
|
try:
|
||||||
|
if self.sandbox_manager:
|
||||||
|
await self.sandbox_manager.cleanup()
|
||||||
|
await self.event_manager.close()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Cleanup error: {e}")
|
||||||
|
|
||||||
|
def cancel(self):
|
||||||
|
"""取消任务"""
|
||||||
|
self._cancelled = True
|
||||||
|
|
||||||
|
|
||||||
|
# 便捷函数
|
||||||
|
async def run_agent_task(
|
||||||
|
db: AsyncSession,
|
||||||
|
task: AgentTask,
|
||||||
|
project_root: str,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
运行 Agent 审计任务
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: 数据库会话
|
||||||
|
task: Agent 任务
|
||||||
|
project_root: 项目根目录
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
审计结果
|
||||||
|
"""
|
||||||
|
runner = AgentRunner(db, task, project_root)
|
||||||
|
return await runner.run()
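

# A minimal usage sketch (assumption: an async session factory `async_session` exists in the
# caller; the function and variable names below are illustrative, not part of this API):
#
#   async def start_audit(task_id: str, project_root: str):
#       async with async_session() as db:
#           task = await db.get(AgentTask, task_id)
#           result = await run_agent_task(db, task, project_root)
#           return result["success"], result.get("data", {})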
|
||||||
|
|
||||||
|
|
@ -0,0 +1,20 @@
"""
Agent prompt module
"""

from .system_prompts import (
    ORCHESTRATOR_SYSTEM_PROMPT,
    ANALYSIS_SYSTEM_PROMPT,
    VERIFICATION_SYSTEM_PROMPT,
    PLANNING_PROMPT,
    REPORTING_PROMPT,
)

__all__ = [
    "ORCHESTRATOR_SYSTEM_PROMPT",
    "ANALYSIS_SYSTEM_PROMPT",
    "VERIFICATION_SYSTEM_PROMPT",
    "PLANNING_PROMPT",
    "REPORTING_PROMPT",
]
@ -0,0 +1,170 @@
|
||||||
|
"""
|
||||||
|
Agent 系统提示词
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 编排 Agent 系统提示词
|
||||||
|
ORCHESTRATOR_SYSTEM_PROMPT = """你是一个专业的代码安全审计 Agent,负责自主分析代码并发现安全漏洞。
|
||||||
|
|
||||||
|
## 你的职责
|
||||||
|
1. 分析项目代码,制定审计计划
|
||||||
|
2. 使用工具深入分析代码
|
||||||
|
3. 发现并验证安全漏洞
|
||||||
|
4. 生成详细的漏洞报告
|
||||||
|
|
||||||
|
## 审计流程
|
||||||
|
1. **规划阶段**: 分析项目结构,识别高风险区域,制定审计计划
|
||||||
|
2. **索引阶段**: 等待代码索引完成
|
||||||
|
3. **分析阶段**: 使用工具进行深度代码分析
|
||||||
|
4. **验证阶段**: 在沙箱中验证发现的漏洞
|
||||||
|
5. **报告阶段**: 整理发现,生成报告
|
||||||
|
|
||||||
|
## 重点关注的漏洞类型
|
||||||
|
- SQL 注入(包括 ORM 注入)
|
||||||
|
- XSS 跨站脚本(反射型、存储型、DOM型)
|
||||||
|
- 命令注入和代码注入
|
||||||
|
- 路径遍历和任意文件访问
|
||||||
|
- SSRF 服务端请求伪造
|
||||||
|
- XXE XML 外部实体注入
|
||||||
|
- 不安全的反序列化
|
||||||
|
- 认证和授权绕过
|
||||||
|
- 敏感信息泄露(硬编码密钥、日志泄露)
|
||||||
|
- 业务逻辑漏洞
|
||||||
|
- IDOR 不安全的直接对象引用
|
||||||
|
|
||||||
|
## 分析方法
|
||||||
|
1. **快速扫描**: 首先使用 pattern_match 快速发现可疑代码
|
||||||
|
2. **语义搜索**: 使用 rag_query 查找相关上下文
|
||||||
|
3. **深度分析**: 对可疑代码使用 code_analysis 深入分析
|
||||||
|
4. **数据流追踪**: 追踪用户输入到危险函数的路径
|
||||||
|
5. **漏洞验证**: 在沙箱中验证发现的漏洞
|
||||||
|
|
||||||
|
## 工作原则
|
||||||
|
- 系统性: 不遗漏任何可能的攻击面
|
||||||
|
- 精准性: 减少误报,每个发现都要有充分证据
|
||||||
|
- 深入性: 不只看表面,要理解代码逻辑
|
||||||
|
- 可操作性: 提供具体的修复建议
|
||||||
|
|
||||||
|
## 输出要求
|
||||||
|
发现漏洞时,提供:
|
||||||
|
- 漏洞类型和严重程度
|
||||||
|
- 具体位置(文件、行号)
|
||||||
|
- 漏洞描述和成因
|
||||||
|
- 利用方式和影响
|
||||||
|
- 修复建议和示例代码
|
||||||
|
|
||||||
|
请开始审计工作,使用可用的工具进行分析。"""
|
||||||
|
|
||||||
|
# 分析 Agent 系统提示词
|
||||||
|
ANALYSIS_SYSTEM_PROMPT = """你是一个专注于代码漏洞分析的安全专家。
|
||||||
|
|
||||||
|
## 你的任务
|
||||||
|
深入分析代码,发现安全漏洞。你需要:
|
||||||
|
1. 识别危险的代码模式
|
||||||
|
2. 追踪数据流(从用户输入到危险函数)
|
||||||
|
3. 判断漏洞是否可利用
|
||||||
|
4. 评估漏洞的严重程度
|
||||||
|
|
||||||
|
## 可用工具
|
||||||
|
- rag_query: 语义搜索相关代码
|
||||||
|
- pattern_match: 快速模式匹配
|
||||||
|
- code_analysis: LLM 深度分析
|
||||||
|
- read_file: 读取文件内容
|
||||||
|
- search_code: 关键字搜索
|
||||||
|
- dataflow_analysis: 数据流分析
|
||||||
|
- vulnerability_validation: 漏洞验证
|
||||||
|
|
||||||
|
## 分析策略
|
||||||
|
1. 先全局后局部:先了解整体架构,再深入细节
|
||||||
|
2. 先快后深:先快速扫描,再深入可疑点
|
||||||
|
3. 追踪数据流:用户输入 → 处理逻辑 → 危险函数
|
||||||
|
4. 验证每个发现:确保不是误报
|
||||||
|
|
||||||
|
## 严重程度评估标准
|
||||||
|
- **Critical**: 可直接导致系统被控制或大规模数据泄露
|
||||||
|
- **High**: 可导致敏感数据泄露或重要功能被绕过
|
||||||
|
- **Medium**: 可导致部分数据泄露或需要特定条件利用
|
||||||
|
- **Low**: 影响有限或利用条件苛刻
|
||||||
|
|
||||||
|
请开始分析,专注于发现真实的安全漏洞。"""
|
||||||
|
|
||||||
|
# 验证 Agent 系统提示词
|
||||||
|
VERIFICATION_SYSTEM_PROMPT = """你是一个专注于漏洞验证的安全专家。
|
||||||
|
|
||||||
|
## 你的任务
|
||||||
|
验证发现的漏洞是否真实存在,判断是否为误报。
|
||||||
|
|
||||||
|
## 验证方法
|
||||||
|
1. **代码审查**: 仔细分析漏洞代码和上下文
|
||||||
|
2. **构造 Payload**: 设计能触发漏洞的输入
|
||||||
|
3. **沙箱测试**: 在隔离环境中测试漏洞
|
||||||
|
4. **分析结果**: 判断漏洞是否可利用
|
||||||
|
|
||||||
|
## 可用工具
|
||||||
|
- sandbox_exec: 在沙箱中执行命令
|
||||||
|
- sandbox_http: 发送 HTTP 请求
|
||||||
|
- verify_vulnerability: 自动验证漏洞
|
||||||
|
- vulnerability_validation: 深度验证分析
|
||||||
|
|
||||||
|
## 验证原则
|
||||||
|
- 安全第一:所有测试在沙箱中进行
|
||||||
|
- 证据充分:验证结果要有明确证据
|
||||||
|
- 谨慎判断:不确定时标记为需要人工审核
|
||||||
|
|
||||||
|
## 输出要求
|
||||||
|
- 验证结果:确认/可能/误报
|
||||||
|
- 验证方法:使用的测试方法
|
||||||
|
- 证据:支持判断的具体证据
|
||||||
|
- PoC(如果确认):可复现的测试代码
|
||||||
|
|
||||||
|
请开始验证工作。"""
|
||||||
|
|
||||||
|
# 规划提示词
|
||||||
|
PLANNING_PROMPT = """基于以下项目信息,制定安全审计计划。
|
||||||
|
|
||||||
|
## 项目信息
|
||||||
|
- 名称: {project_name}
|
||||||
|
- 语言: {languages}
|
||||||
|
- 文件数量: {file_count}
|
||||||
|
- 目录结构: {directory_structure}
|
||||||
|
|
||||||
|
## 请输出审计计划
|
||||||
|
包含以下内容(JSON格式):
|
||||||
|
```json
|
||||||
|
{{
|
||||||
|
"high_risk_areas": ["高风险目录/文件列表"],
|
||||||
|
"focus_vulnerabilities": ["重点关注的漏洞类型"],
|
||||||
|
"audit_order": ["审计顺序"],
|
||||||
|
"estimated_steps": "预计步骤数",
|
||||||
|
"special_attention": ["特别注意事项"]
|
||||||
|
}}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 高风险区域识别原则
|
||||||
|
1. 用户认证和授权相关代码
|
||||||
|
2. 数据库操作和 ORM 使用
|
||||||
|
3. 文件上传和下载功能
|
||||||
|
4. API 接口和输入处理
|
||||||
|
5. 第三方服务调用
|
||||||
|
6. 配置文件和环境变量
|
||||||
|
7. 加密和密钥管理"""
|
||||||
|
|
||||||
|
# 报告生成提示词
|
||||||
|
REPORTING_PROMPT = """基于审计发现,生成安全审计报告摘要。
|
||||||
|
|
||||||
|
## 审计发现
|
||||||
|
{findings}
|
||||||
|
|
||||||
|
## 统计信息
|
||||||
|
- 总发现数: {total_findings}
|
||||||
|
- 已验证: {verified_count}
|
||||||
|
- 严重程度分布: {severity_distribution}
|
||||||
|
|
||||||
|
## 请输出报告摘要
|
||||||
|
包含以下内容:
|
||||||
|
1. 整体安全评估
|
||||||
|
2. 主要风险点
|
||||||
|
3. 优先修复建议
|
||||||
|
4. 安全改进建议
|
||||||
|
|
||||||
|
请用简洁专业的语言描述。"""
|
||||||
|
|
||||||
|
|
@ -0,0 +1,61 @@
|
||||||
|
"""
|
||||||
|
Agent 工具集
|
||||||
|
提供 LangChain Agent 使用的各种工具
|
||||||
|
包括内置工具和外部安全工具
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .base import AgentTool, ToolResult
|
||||||
|
from .rag_tool import RAGQueryTool, SecurityCodeSearchTool, FunctionContextTool
|
||||||
|
from .pattern_tool import PatternMatchTool
|
||||||
|
from .code_analysis_tool import CodeAnalysisTool, DataFlowAnalysisTool, VulnerabilityValidationTool
|
||||||
|
from .file_tool import FileReadTool, FileSearchTool, ListFilesTool
|
||||||
|
from .sandbox_tool import SandboxTool, SandboxHttpTool, VulnerabilityVerifyTool, SandboxManager
|
||||||
|
|
||||||
|
# 外部安全工具
|
||||||
|
from .external_tools import (
|
||||||
|
SemgrepTool,
|
||||||
|
BanditTool,
|
||||||
|
GitleaksTool,
|
||||||
|
NpmAuditTool,
|
||||||
|
SafetyTool,
|
||||||
|
TruffleHogTool,
|
||||||
|
OSVScannerTool,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# 基础
|
||||||
|
"AgentTool",
|
||||||
|
"ToolResult",
|
||||||
|
|
||||||
|
# RAG 工具
|
||||||
|
"RAGQueryTool",
|
||||||
|
"SecurityCodeSearchTool",
|
||||||
|
"FunctionContextTool",
|
||||||
|
|
||||||
|
# 代码分析
|
||||||
|
"PatternMatchTool",
|
||||||
|
"CodeAnalysisTool",
|
||||||
|
"DataFlowAnalysisTool",
|
||||||
|
"VulnerabilityValidationTool",
|
||||||
|
|
||||||
|
# 文件操作
|
||||||
|
"FileReadTool",
|
||||||
|
"FileSearchTool",
|
||||||
|
"ListFilesTool",
|
||||||
|
|
||||||
|
# 沙箱
|
||||||
|
"SandboxTool",
|
||||||
|
"SandboxHttpTool",
|
||||||
|
"VulnerabilityVerifyTool",
|
||||||
|
"SandboxManager",
|
||||||
|
|
||||||
|
# 外部安全工具
|
||||||
|
"SemgrepTool",
|
||||||
|
"BanditTool",
|
||||||
|
"GitleaksTool",
|
||||||
|
"NpmAuditTool",
|
||||||
|
"SafetyTool",
|
||||||
|
"TruffleHogTool",
|
||||||
|
"OSVScannerTool",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
@ -0,0 +1,156 @@
|
||||||
|
"""
|
||||||
|
Agent 工具基类
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, Dict, Optional, Type
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pydantic import BaseModel
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ToolResult:
|
||||||
|
"""工具执行结果"""
|
||||||
|
success: bool
|
||||||
|
data: Any = None
|
||||||
|
error: Optional[str] = None
|
||||||
|
duration_ms: int = 0
|
||||||
|
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"success": self.success,
|
||||||
|
"data": self.data,
|
||||||
|
"error": self.error,
|
||||||
|
"duration_ms": self.duration_ms,
|
||||||
|
"metadata": self.metadata,
|
||||||
|
}
|
||||||
|
|
||||||
|
def to_string(self, max_length: int = 5000) -> str:
|
||||||
|
"""转换为字符串(用于 LLM 输出)"""
|
||||||
|
if not self.success:
|
||||||
|
return f"Error: {self.error}"
|
||||||
|
|
||||||
|
if isinstance(self.data, str):
|
||||||
|
result = self.data
|
||||||
|
elif isinstance(self.data, (dict, list)):
|
||||||
|
import json
|
||||||
|
result = json.dumps(self.data, ensure_ascii=False, indent=2)
|
||||||
|
else:
|
||||||
|
result = str(self.data)
|
||||||
|
|
||||||
|
if len(result) > max_length:
|
||||||
|
result = result[:max_length] + f"\n... (truncated, total {len(result)} chars)"
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class AgentTool(ABC):
|
||||||
|
"""
|
||||||
|
Agent 工具基类
|
||||||
|
所有工具需要继承此类并实现必要的方法
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._call_count = 0
|
||||||
|
self._total_duration_ms = 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def name(self) -> str:
|
||||||
|
"""工具名称"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def description(self) -> str:
|
||||||
|
"""工具描述(用于 Agent 理解工具功能)"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
def args_schema(self) -> Optional[Type[BaseModel]]:
|
||||||
|
"""参数 Schema(Pydantic 模型)"""
|
||||||
|
return None
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def _execute(self, **kwargs) -> ToolResult:
|
||||||
|
"""执行工具(子类实现)"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def execute(self, **kwargs) -> ToolResult:
|
||||||
|
"""执行工具(带计时和日志)"""
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.debug(f"Tool '{self.name}' executing with args: {kwargs}")
|
||||||
|
result = await self._execute(**kwargs)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Tool '{self.name}' error: {e}", exc_info=True)
|
||||||
|
result = ToolResult(
|
||||||
|
success=False,
|
||||||
|
error=str(e),
|
||||||
|
)
|
||||||
|
|
||||||
|
duration_ms = int((time.time() - start_time) * 1000)
|
||||||
|
result.duration_ms = duration_ms
|
||||||
|
|
||||||
|
self._call_count += 1
|
||||||
|
self._total_duration_ms += duration_ms
|
||||||
|
|
||||||
|
logger.debug(f"Tool '{self.name}' completed in {duration_ms}ms, success={result.success}")
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def get_langchain_tool(self):
|
||||||
|
"""转换为 LangChain Tool"""
|
||||||
|
from langchain.tools import Tool, StructuredTool
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
def sync_wrapper(**kwargs):
|
||||||
|
"""同步包装器"""
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
if loop.is_running():
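                # An event loop is already running (e.g. the tool is invoked from async agent
                # code through LangChain's sync path), so run the coroutine on a fresh loop in
                # a worker thread instead of blocking the current loop.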
|
||||||
|
import concurrent.futures
|
||||||
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||||
|
future = executor.submit(asyncio.run, self.execute(**kwargs))
|
||||||
|
result = future.result()
|
||||||
|
else:
|
||||||
|
result = asyncio.run(self.execute(**kwargs))
|
||||||
|
return result.to_string()
|
||||||
|
|
||||||
|
async def async_wrapper(**kwargs):
|
||||||
|
"""异步包装器"""
|
||||||
|
result = await self.execute(**kwargs)
|
||||||
|
return result.to_string()
|
||||||
|
|
||||||
|
if self.args_schema:
|
||||||
|
return StructuredTool(
|
||||||
|
name=self.name,
|
||||||
|
description=self.description,
|
||||||
|
func=sync_wrapper,
|
||||||
|
coroutine=async_wrapper,
|
||||||
|
args_schema=self.args_schema,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return Tool(
|
||||||
|
name=self.name,
|
||||||
|
description=self.description,
|
||||||
|
func=lambda x: sync_wrapper(query=x),
|
||||||
|
coroutine=lambda x: async_wrapper(query=x),
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def stats(self) -> Dict[str, Any]:
|
||||||
|
"""工具使用统计"""
|
||||||
|
return {
|
||||||
|
"name": self.name,
|
||||||
|
"call_count": self._call_count,
|
||||||
|
"total_duration_ms": self._total_duration_ms,
|
||||||
|
"avg_duration_ms": self._total_duration_ms // max(1, self._call_count),
|
||||||
|
}
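

# A minimal, hypothetical concrete tool, to illustrate the contract subclasses implement
# (the class, name, and behaviour below are illustrative only and not part of this module):
#
#   class EchoTool(AgentTool):
#       @property
#       def name(self) -> str:
#           return "echo"
#
#       @property
#       def description(self) -> str:
#           return "Return the input text unchanged (debugging aid)."
#
#       async def _execute(self, query: str = "", **kwargs) -> ToolResult:
#           return ToolResult(success=True, data=query)
#
#   # tool = EchoTool(); result = await tool.execute(query="ping")  # result.data == "ping"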
|
||||||
|
|
||||||
|
|
@ -0,0 +1,427 @@
|
||||||
|
"""
|
||||||
|
代码分析工具
|
||||||
|
使用 LLM 深度分析代码安全问题
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Optional, List, Dict, Any
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from .base import AgentTool, ToolResult
|
||||||
|
|
||||||
|
|
||||||
|
class CodeAnalysisInput(BaseModel):
|
||||||
|
"""代码分析输入"""
|
||||||
|
code: str = Field(description="要分析的代码内容")
|
||||||
|
file_path: str = Field(default="unknown", description="文件路径")
|
||||||
|
language: str = Field(default="python", description="编程语言")
|
||||||
|
focus: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="重点关注的漏洞类型,如 sql_injection, xss, command_injection"
|
||||||
|
)
|
||||||
|
context: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="额外的上下文信息,如相关的其他代码片段"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CodeAnalysisTool(AgentTool):
|
||||||
|
"""
|
||||||
|
代码分析工具
|
||||||
|
使用 LLM 对代码进行深度安全分析
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, llm_service):
|
||||||
|
"""
|
||||||
|
初始化代码分析工具
|
||||||
|
|
||||||
|
Args:
|
||||||
|
llm_service: LLM 服务实例
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.llm_service = llm_service
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "code_analysis"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return """深度分析代码安全问题。
|
||||||
|
使用 LLM 对代码进行全面的安全审计,识别潜在漏洞。
|
||||||
|
|
||||||
|
使用场景:
|
||||||
|
- 对疑似有问题的代码进行深入分析
|
||||||
|
- 分析复杂的业务逻辑漏洞
|
||||||
|
- 追踪数据流和污点传播
|
||||||
|
- 生成详细的漏洞报告和修复建议
|
||||||
|
|
||||||
|
输入:
|
||||||
|
- code: 要分析的代码
|
||||||
|
- file_path: 文件路径
|
||||||
|
- language: 编程语言
|
||||||
|
- focus: 可选,重点关注的漏洞类型
|
||||||
|
- context: 可选,额外的上下文代码
|
||||||
|
|
||||||
|
这个工具会消耗较多的 Token,建议在确认有疑似问题后使用。"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def args_schema(self):
|
||||||
|
return CodeAnalysisInput
|
||||||
|
|
||||||
|
async def _execute(
|
||||||
|
self,
|
||||||
|
code: str,
|
||||||
|
file_path: str = "unknown",
|
||||||
|
language: str = "python",
|
||||||
|
focus: Optional[str] = None,
|
||||||
|
context: Optional[str] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> ToolResult:
|
||||||
|
"""执行代码分析"""
|
||||||
|
try:
|
||||||
|
            # Run the LLM-based security analysis on the snippet
|
||||||
|
analysis = await self.llm_service.analyze_code(code, language)
|
||||||
|
|
||||||
|
issues = analysis.get("issues", [])
|
||||||
|
|
||||||
|
if not issues:
|
||||||
|
return ToolResult(
|
||||||
|
success=True,
|
||||||
|
data="代码分析完成,未发现明显的安全问题。\n\n"
|
||||||
|
f"质量评分: {analysis.get('quality_score', 'N/A')}\n"
|
||||||
|
f"文件: {file_path}",
|
||||||
|
metadata={
|
||||||
|
"file_path": file_path,
|
||||||
|
"issues_count": 0,
|
||||||
|
"quality_score": analysis.get("quality_score"),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 格式化输出
|
||||||
|
output_parts = [f"🔍 代码分析结果 - {file_path}\n"]
|
||||||
|
output_parts.append(f"发现 {len(issues)} 个问题:\n")
|
||||||
|
|
||||||
|
for i, issue in enumerate(issues):
|
||||||
|
severity_icon = {
|
||||||
|
"critical": "🔴",
|
||||||
|
"high": "🟠",
|
||||||
|
"medium": "🟡",
|
||||||
|
"low": "🟢"
|
||||||
|
}.get(issue.get("severity", ""), "⚪")
|
||||||
|
|
||||||
|
output_parts.append(f"\n{severity_icon} 问题 {i+1}: {issue.get('title', 'Unknown')}")
|
||||||
|
output_parts.append(f" 类型: {issue.get('type', 'unknown')}")
|
||||||
|
output_parts.append(f" 严重程度: {issue.get('severity', 'unknown')}")
|
||||||
|
output_parts.append(f" 行号: {issue.get('line', 'N/A')}")
|
||||||
|
output_parts.append(f" 描述: {issue.get('description', '')}")
|
||||||
|
|
||||||
|
if issue.get("code_snippet"):
|
||||||
|
output_parts.append(f" 代码片段:\n ```\n {issue.get('code_snippet')}\n ```")
|
||||||
|
|
||||||
|
if issue.get("suggestion"):
|
||||||
|
output_parts.append(f" 修复建议: {issue.get('suggestion')}")
|
||||||
|
|
||||||
|
if issue.get("ai_explanation"):
|
||||||
|
output_parts.append(f" AI解释: {issue.get('ai_explanation')}")
|
||||||
|
|
||||||
|
output_parts.append(f"\n质量评分: {analysis.get('quality_score', 'N/A')}/100")
|
||||||
|
|
||||||
|
return ToolResult(
|
||||||
|
success=True,
|
||||||
|
data="\n".join(output_parts),
|
||||||
|
metadata={
|
||||||
|
"file_path": file_path,
|
||||||
|
"issues_count": len(issues),
|
||||||
|
"quality_score": analysis.get("quality_score"),
|
||||||
|
"issues": issues,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return ToolResult(
|
||||||
|
success=False,
|
||||||
|
error=f"代码分析失败: {str(e)}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DataFlowAnalysisInput(BaseModel):
|
||||||
|
"""数据流分析输入"""
|
||||||
|
source_code: str = Field(description="包含数据源的代码")
|
||||||
|
sink_code: Optional[str] = Field(default=None, description="包含数据汇的代码(如危险函数)")
|
||||||
|
variable_name: str = Field(description="要追踪的变量名")
|
||||||
|
file_path: str = Field(default="unknown", description="文件路径")
|
||||||
|
|
||||||
|
|
||||||
|
class DataFlowAnalysisTool(AgentTool):
|
||||||
|
"""
|
||||||
|
数据流分析工具
|
||||||
|
追踪变量从源到汇的数据流
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, llm_service):
|
||||||
|
super().__init__()
|
||||||
|
self.llm_service = llm_service
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "dataflow_analysis"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return """分析代码中的数据流,追踪变量从源(如用户输入)到汇(如危险函数)的路径。
|
||||||
|
|
||||||
|
使用场景:
|
||||||
|
- 追踪用户输入如何流向危险函数
|
||||||
|
- 分析变量是否经过净化处理
|
||||||
|
- 识别污点传播路径
|
||||||
|
|
||||||
|
输入:
|
||||||
|
- source_code: 包含数据源的代码
|
||||||
|
- sink_code: 包含数据汇的代码(可选)
|
||||||
|
- variable_name: 要追踪的变量名
|
||||||
|
- file_path: 文件路径"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def args_schema(self):
|
||||||
|
return DataFlowAnalysisInput
|
||||||
|
|
||||||
|
async def _execute(
|
||||||
|
self,
|
||||||
|
source_code: str,
|
||||||
|
variable_name: str,
|
||||||
|
sink_code: Optional[str] = None,
|
||||||
|
file_path: str = "unknown",
|
||||||
|
**kwargs
|
||||||
|
) -> ToolResult:
|
||||||
|
"""执行数据流分析"""
|
||||||
|
try:
|
||||||
|
# 构建分析 prompt
|
||||||
|
analysis_prompt = f"""分析以下代码中变量 '{variable_name}' 的数据流。
|
||||||
|
|
||||||
|
源代码:
|
||||||
|
```
|
||||||
|
{source_code}
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
if sink_code:
|
||||||
|
analysis_prompt += f"""
|
||||||
|
汇代码(可能的危险函数):
|
||||||
|
```
|
||||||
|
{sink_code}
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
analysis_prompt += f"""
|
||||||
|
请分析:
|
||||||
|
1. 变量 '{variable_name}' 的来源是什么?(用户输入、配置、数据库等)
|
||||||
|
2. 变量在传递过程中是否经过了净化/验证?
|
||||||
|
3. 变量最终流向了哪些危险函数?
|
||||||
|
4. 是否存在安全风险?
|
||||||
|
|
||||||
|
请返回 JSON 格式的分析结果,包含:
|
||||||
|
- source_type: 数据源类型
|
||||||
|
- sanitized: 是否经过净化
|
||||||
|
- sanitization_methods: 使用的净化方法
|
||||||
|
- dangerous_sinks: 流向的危险函数列表
|
||||||
|
- risk_level: 风险等级 (high/medium/low/none)
|
||||||
|
- explanation: 详细解释
|
||||||
|
- recommendation: 建议
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 调用 LLM 分析
|
||||||
|
# 这里使用 analyze_code_with_custom_prompt
|
||||||
|
result = await self.llm_service.analyze_code_with_custom_prompt(
|
||||||
|
code=source_code,
|
||||||
|
language="text",
|
||||||
|
custom_prompt=analysis_prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 格式化输出
|
||||||
|
output_parts = [f"📊 数据流分析结果 - 变量: {variable_name}\n"]
|
||||||
|
|
||||||
|
if isinstance(result, dict):
|
||||||
|
if result.get("source_type"):
|
||||||
|
output_parts.append(f"数据源: {result.get('source_type')}")
|
||||||
|
if result.get("sanitized") is not None:
|
||||||
|
sanitized = "✅ 是" if result.get("sanitized") else "❌ 否"
|
||||||
|
output_parts.append(f"是否净化: {sanitized}")
|
||||||
|
if result.get("sanitization_methods"):
|
||||||
|
output_parts.append(f"净化方法: {', '.join(result.get('sanitization_methods', []))}")
|
||||||
|
if result.get("dangerous_sinks"):
|
||||||
|
output_parts.append(f"危险函数: {', '.join(result.get('dangerous_sinks', []))}")
|
||||||
|
if result.get("risk_level"):
|
||||||
|
risk_icons = {"high": "🔴", "medium": "🟠", "low": "🟡", "none": "🟢"}
|
||||||
|
icon = risk_icons.get(result.get("risk_level", ""), "⚪")
|
||||||
|
output_parts.append(f"风险等级: {icon} {result.get('risk_level', '').upper()}")
|
||||||
|
if result.get("explanation"):
|
||||||
|
output_parts.append(f"\n分析: {result.get('explanation')}")
|
||||||
|
if result.get("recommendation"):
|
||||||
|
output_parts.append(f"\n建议: {result.get('recommendation')}")
|
||||||
|
else:
|
||||||
|
output_parts.append(str(result))
|
||||||
|
|
||||||
|
return ToolResult(
|
||||||
|
success=True,
|
||||||
|
data="\n".join(output_parts),
|
||||||
|
metadata={
|
||||||
|
"variable": variable_name,
|
||||||
|
"file_path": file_path,
|
||||||
|
"analysis": result,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return ToolResult(
|
||||||
|
success=False,
|
||||||
|
error=f"数据流分析失败: {str(e)}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class VulnerabilityValidationInput(BaseModel):
|
||||||
|
"""漏洞验证输入"""
|
||||||
|
code: str = Field(description="可能存在漏洞的代码")
|
||||||
|
vulnerability_type: str = Field(description="漏洞类型")
|
||||||
|
file_path: str = Field(default="unknown", description="文件路径")
|
||||||
|
line_number: Optional[int] = Field(default=None, description="行号")
|
||||||
|
context: Optional[str] = Field(default=None, description="额外上下文")
|
||||||
|
|
||||||
|
|
||||||
|
class VulnerabilityValidationTool(AgentTool):
|
||||||
|
"""
|
||||||
|
漏洞验证工具
|
||||||
|
验证疑似漏洞是否真实存在
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, llm_service):
|
||||||
|
super().__init__()
|
||||||
|
self.llm_service = llm_service
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "vulnerability_validation"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return """验证疑似漏洞是否真实存在。
|
||||||
|
对发现的潜在漏洞进行深入分析,判断是否为真正的安全问题。
|
||||||
|
|
||||||
|
输入:
|
||||||
|
- code: 包含疑似漏洞的代码
|
||||||
|
- vulnerability_type: 漏洞类型(如 sql_injection, xss 等)
|
||||||
|
- file_path: 文件路径
|
||||||
|
- line_number: 可选,行号
|
||||||
|
- context: 可选,额外的上下文代码
|
||||||
|
|
||||||
|
输出:
|
||||||
|
- 验证结果(确认/可能/误报)
|
||||||
|
- 详细分析
|
||||||
|
- 利用条件
|
||||||
|
- PoC 思路(如果确认存在漏洞)"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def args_schema(self):
|
||||||
|
return VulnerabilityValidationInput
|
||||||
|
|
||||||
|
async def _execute(
|
||||||
|
self,
|
||||||
|
code: str,
|
||||||
|
vulnerability_type: str,
|
||||||
|
file_path: str = "unknown",
|
||||||
|
line_number: Optional[int] = None,
|
||||||
|
context: Optional[str] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> ToolResult:
|
||||||
|
"""执行漏洞验证"""
|
||||||
|
try:
|
||||||
|
validation_prompt = f"""你是一个专业的安全研究员,请验证以下代码中是否真的存在 {vulnerability_type} 漏洞。
|
||||||
|
|
||||||
|
代码:
|
||||||
|
```
|
||||||
|
{code}
|
||||||
|
```
|
||||||
|
|
||||||
|
{f'额外上下文:' + chr(10) + '```' + chr(10) + context + chr(10) + '```' if context else ''}
|
||||||
|
|
||||||
|
请分析:
|
||||||
|
1. 这段代码是否真的存在 {vulnerability_type} 漏洞?
|
||||||
|
2. 漏洞的利用条件是什么?
|
||||||
|
3. 攻击者如何利用这个漏洞?
|
||||||
|
4. 这是否可能是误报?为什么?
|
||||||
|
|
||||||
|
请返回 JSON 格式:
|
||||||
|
{{
|
||||||
|
"is_vulnerable": true/false/null (null表示无法确定),
|
||||||
|
"confidence": 0.0-1.0,
|
||||||
|
"verdict": "confirmed/likely/unlikely/false_positive",
|
||||||
|
"exploitation_conditions": ["条件1", "条件2"],
|
||||||
|
"attack_vector": "攻击向量描述",
|
||||||
|
"poc_idea": "PoC思路(如果存在漏洞)",
|
||||||
|
"false_positive_reason": "如果是误报,说明原因",
|
||||||
|
"detailed_analysis": "详细分析"
|
||||||
|
}}
|
||||||
|
"""
|
||||||
|
|
||||||
|
result = await self.llm_service.analyze_code_with_custom_prompt(
|
||||||
|
code=code,
|
||||||
|
language="text",
|
||||||
|
custom_prompt=validation_prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 格式化输出
|
||||||
|
output_parts = [f"🔎 漏洞验证结果 - {vulnerability_type}\n"]
|
||||||
|
output_parts.append(f"文件: {file_path}")
|
||||||
|
if line_number:
|
||||||
|
output_parts.append(f"行号: {line_number}")
|
||||||
|
output_parts.append("")
|
||||||
|
|
||||||
|
if isinstance(result, dict):
|
||||||
|
# 验证结果
|
||||||
|
verdict_icons = {
|
||||||
|
"confirmed": "🔴 确认存在漏洞",
|
||||||
|
"likely": "🟠 可能存在漏洞",
|
||||||
|
"unlikely": "🟡 可能是误报",
|
||||||
|
"false_positive": "🟢 误报",
|
||||||
|
}
|
||||||
|
verdict = result.get("verdict", "unknown")
|
||||||
|
output_parts.append(f"判定: {verdict_icons.get(verdict, verdict)}")
|
||||||
|
|
||||||
|
if result.get("confidence"):
|
||||||
|
output_parts.append(f"置信度: {result.get('confidence') * 100:.0f}%")
|
||||||
|
|
||||||
|
if result.get("exploitation_conditions"):
|
||||||
|
output_parts.append(f"\n利用条件:")
|
||||||
|
for cond in result.get("exploitation_conditions", []):
|
||||||
|
output_parts.append(f" - {cond}")
|
||||||
|
|
||||||
|
if result.get("attack_vector"):
|
||||||
|
output_parts.append(f"\n攻击向量: {result.get('attack_vector')}")
|
||||||
|
|
||||||
|
if result.get("poc_idea") and verdict in ["confirmed", "likely"]:
|
||||||
|
output_parts.append(f"\nPoC思路: {result.get('poc_idea')}")
|
||||||
|
|
||||||
|
if result.get("false_positive_reason") and verdict in ["unlikely", "false_positive"]:
|
||||||
|
output_parts.append(f"\n误报原因: {result.get('false_positive_reason')}")
|
||||||
|
|
||||||
|
if result.get("detailed_analysis"):
|
||||||
|
output_parts.append(f"\n详细分析:\n{result.get('detailed_analysis')}")
|
||||||
|
else:
|
||||||
|
output_parts.append(str(result))
|
||||||
|
|
||||||
|
return ToolResult(
|
||||||
|
success=True,
|
||||||
|
data="\n".join(output_parts),
|
||||||
|
metadata={
|
||||||
|
"vulnerability_type": vulnerability_type,
|
||||||
|
"file_path": file_path,
|
||||||
|
"line_number": line_number,
|
||||||
|
"validation": result,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return ToolResult(
|
||||||
|
success=False,
|
||||||
|
error=f"漏洞验证失败: {str(e)}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
@ -0,0 +1,948 @@
|
||||||
|
"""
|
||||||
|
外部安全工具集成
|
||||||
|
集成 Semgrep、Bandit、Gitleaks、TruffleHog、npm audit 等专业安全工具
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import shutil
|
||||||
|
from typing import Optional, List, Dict, Any
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from .base import AgentTool, ToolResult
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ============ Semgrep 工具 ============
|
||||||
|
|
||||||
|
class SemgrepInput(BaseModel):
|
||||||
|
"""Semgrep 扫描输入"""
|
||||||
|
target_path: str = Field(description="要扫描的目录或文件路径(相对于项目根目录)")
|
||||||
|
rules: Optional[str] = Field(
|
||||||
|
default="auto",
|
||||||
|
description="规则集: auto, p/security-audit, p/owasp-top-ten, p/r2c-security-audit, 或自定义规则文件路径"
|
||||||
|
)
|
||||||
|
severity: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="过滤严重程度: ERROR, WARNING, INFO"
|
||||||
|
)
|
||||||
|
max_results: int = Field(default=50, description="最大返回结果数")
|
||||||
|
|
||||||
|
|
||||||
|
class SemgrepTool(AgentTool):
|
||||||
|
"""
|
||||||
|
Semgrep 静态分析工具
|
||||||
|
|
||||||
|
Semgrep 是一款快速、轻量级的静态分析工具,支持多种编程语言。
|
||||||
|
提供丰富的安全规则库,可以检测各种安全漏洞。
|
||||||
|
|
||||||
|
官方规则集:
|
||||||
|
- p/security-audit: 综合安全审计
|
||||||
|
- p/owasp-top-ten: OWASP Top 10 漏洞
|
||||||
|
- p/r2c-security-audit: R2C 安全审计规则
|
||||||
|
- p/python: Python 特定规则
|
||||||
|
- p/javascript: JavaScript 特定规则
|
||||||
|
"""
|
||||||
|
|
||||||
|
AVAILABLE_RULESETS = [
|
||||||
|
"auto",
|
||||||
|
"p/security-audit",
|
||||||
|
"p/owasp-top-ten",
|
||||||
|
"p/r2c-security-audit",
|
||||||
|
"p/python",
|
||||||
|
"p/javascript",
|
||||||
|
"p/typescript",
|
||||||
|
"p/java",
|
||||||
|
"p/go",
|
||||||
|
"p/php",
|
||||||
|
"p/ruby",
|
||||||
|
"p/secrets",
|
||||||
|
"p/sql-injection",
|
||||||
|
"p/xss",
|
||||||
|
"p/command-injection",
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self, project_root: str):
|
||||||
|
super().__init__()
|
||||||
|
self.project_root = project_root
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "semgrep_scan"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return """使用 Semgrep 进行静态安全分析。
|
||||||
|
Semgrep 是业界领先的静态分析工具,支持 30+ 种编程语言。
|
||||||
|
|
||||||
|
可用规则集:
|
||||||
|
- auto: 自动选择最佳规则
|
||||||
|
- p/security-audit: 综合安全审计
|
||||||
|
- p/owasp-top-ten: OWASP Top 10 漏洞检测
|
||||||
|
- p/secrets: 密钥泄露检测
|
||||||
|
- p/sql-injection: SQL 注入检测
|
||||||
|
- p/xss: XSS 检测
|
||||||
|
- p/command-injection: 命令注入检测
|
||||||
|
|
||||||
|
使用场景:
|
||||||
|
- 快速全面的代码安全扫描
|
||||||
|
- 检测常见安全漏洞模式
|
||||||
|
- 遵循行业安全标准审计"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def args_schema(self):
|
||||||
|
return SemgrepInput
|
||||||
|
|
||||||
|
async def _execute(
|
||||||
|
self,
|
||||||
|
target_path: str = ".",
|
||||||
|
rules: str = "auto",
|
||||||
|
severity: Optional[str] = None,
|
||||||
|
max_results: int = 50,
|
||||||
|
**kwargs
|
||||||
|
) -> ToolResult:
|
||||||
|
"""执行 Semgrep 扫描"""
|
||||||
|
# 检查 semgrep 是否可用
|
||||||
|
if not await self._check_semgrep():
|
||||||
|
return ToolResult(
|
||||||
|
success=False,
|
||||||
|
error="Semgrep 未安装。请使用 'pip install semgrep' 安装。",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 构建完整路径
|
||||||
|
full_path = os.path.normpath(os.path.join(self.project_root, target_path))
|
||||||
|
if not full_path.startswith(os.path.normpath(self.project_root)):
|
||||||
|
return ToolResult(
|
||||||
|
success=False,
|
||||||
|
error="安全错误:不允许扫描项目目录外的路径",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 构建命令
|
||||||
|
cmd = ["semgrep", "--json", "--quiet"]
|
||||||
|
|
||||||
|
# --config accepts "auto", an official ruleset (p/...), or a custom rules file path
cmd.extend(["--config", rules])
|
||||||
|
|
||||||
|
if severity:
|
||||||
|
cmd.extend(["--severity", severity])
|
||||||
|
|
||||||
|
cmd.append(full_path)
|
||||||
|
|
||||||
|
try:
|
||||||
|
proc = await asyncio.create_subprocess_exec(
|
||||||
|
*cmd,
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.PIPE,
|
||||||
|
cwd=self.project_root,
|
||||||
|
)
|
||||||
|
stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=300)
|
||||||
|
|
||||||
|
if proc.returncode not in [0, 1]: # 1 means findings were found
|
||||||
|
return ToolResult(
|
||||||
|
success=False,
|
||||||
|
error=f"Semgrep 执行失败: {stderr.decode()[:500]}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 解析结果
|
||||||
|
try:
|
||||||
|
results = json.loads(stdout.decode())
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return ToolResult(
|
||||||
|
success=False,
|
||||||
|
error="无法解析 Semgrep 输出",
|
||||||
|
)
|
||||||
|
|
||||||
|
findings = results.get("results", [])[:max_results]
|
||||||
|
|
||||||
|
if not findings:
|
||||||
|
return ToolResult(
|
||||||
|
success=True,
|
||||||
|
data=f"Semgrep 扫描完成,未发现安全问题 (规则集: {rules})",
|
||||||
|
metadata={"findings_count": 0, "rules": rules}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 格式化输出
|
||||||
|
output_parts = [f"🔍 Semgrep 扫描结果 (规则集: {rules})\n"]
|
||||||
|
output_parts.append(f"发现 {len(findings)} 个问题:\n")
|
||||||
|
|
||||||
|
severity_icons = {"ERROR": "🔴", "WARNING": "🟠", "INFO": "🟡"}
|
||||||
|
|
||||||
|
for i, finding in enumerate(findings[:max_results]):
|
||||||
|
sev = finding.get("extra", {}).get("severity", "INFO")
|
||||||
|
icon = severity_icons.get(sev, "⚪")
|
||||||
|
|
||||||
|
output_parts.append(f"\n{icon} [{sev}] {finding.get('check_id', 'unknown')}")
|
||||||
|
output_parts.append(f" 文件: {finding.get('path', '')}:{finding.get('start', {}).get('line', 0)}")
|
||||||
|
output_parts.append(f" 消息: {finding.get('extra', {}).get('message', '')[:200]}")
|
||||||
|
|
||||||
|
# 代码片段
|
||||||
|
lines = finding.get("extra", {}).get("lines", "")
|
||||||
|
if lines:
|
||||||
|
output_parts.append(f" 代码: {lines[:150]}")
|
||||||
|
|
||||||
|
return ToolResult(
|
||||||
|
success=True,
|
||||||
|
data="\n".join(output_parts),
|
||||||
|
metadata={
|
||||||
|
"findings_count": len(findings),
|
||||||
|
"rules": rules,
|
||||||
|
"findings": findings[:10],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
return ToolResult(success=False, error="Semgrep 扫描超时")
|
||||||
|
except Exception as e:
|
||||||
|
return ToolResult(success=False, error=f"Semgrep 执行错误: {str(e)}")
|
||||||
|
|
||||||
|
async def _check_semgrep(self) -> bool:
|
||||||
|
"""检查 Semgrep 是否可用"""
|
||||||
|
try:
|
||||||
|
proc = await asyncio.create_subprocess_exec(
|
||||||
|
"semgrep", "--version",
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.PIPE,
|
||||||
|
)
|
||||||
|
await proc.communicate()
|
||||||
|
return proc.returncode == 0
|
||||||
|
except Exception:
|
||||||
|
return False
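
# Usage sketch (illustrative): how the agent layer might drive SemgrepTool.
# Only _execute() is defined in this file, so the sketch calls it directly;
# the public entry point exposed by AgentTool is assumed to wrap it.
async def _example_semgrep_usage(project_root: str) -> None:
    tool = SemgrepTool(project_root)
    result = await tool._execute(target_path="src", rules="p/security-audit", severity="ERROR")
    if result.success:
        print(result.data)       # formatted findings text
        print(result.metadata)   # findings_count, rules, first findings
    else:
        print(f"scan failed: {result.error}")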
|
||||||
|
|
||||||
|
|
||||||
|
# ============ Bandit 工具 (Python) ============
|
||||||
|
|
||||||
|
class BanditInput(BaseModel):
|
||||||
|
"""Bandit 扫描输入"""
|
||||||
|
target_path: str = Field(default=".", description="要扫描的 Python 目录或文件")
|
||||||
|
severity: str = Field(default="medium", description="最低严重程度: low, medium, high")
|
||||||
|
confidence: str = Field(default="medium", description="最低置信度: low, medium, high")
|
||||||
|
max_results: int = Field(default=50, description="最大返回结果数")
|
||||||
|
|
||||||
|
|
||||||
|
class BanditTool(AgentTool):
|
||||||
|
"""
|
||||||
|
Bandit Python 安全扫描工具
|
||||||
|
|
||||||
|
Bandit 是专门用于 Python 代码的安全分析工具,
|
||||||
|
可以检测常见的 Python 安全问题,如:
|
||||||
|
- 硬编码密码
|
||||||
|
- SQL 注入
|
||||||
|
- 命令注入
|
||||||
|
- 不安全的随机数生成
|
||||||
|
- 不安全的反序列化
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, project_root: str):
|
||||||
|
super().__init__()
|
||||||
|
self.project_root = project_root
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "bandit_scan"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return """使用 Bandit 扫描 Python 代码的安全问题。
|
||||||
|
Bandit 是 Python 专用的安全分析工具,由 OpenStack 安全团队开发。
|
||||||
|
|
||||||
|
检测项目:
|
||||||
|
- B101: assert 使用
|
||||||
|
- B102: exec 使用
|
||||||
|
- B103-B108: 文件权限问题
|
||||||
|
- B301-B312: pickle/yaml 反序列化
|
||||||
|
- B501-B508: SSL/TLS 问题
|
||||||
|
- B601-B608: shell/SQL 注入
|
||||||
|
- B701-B703: Jinja2 模板问题
|
||||||
|
|
||||||
|
仅适用于 Python 项目。"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def args_schema(self):
|
||||||
|
return BanditInput
|
||||||
|
|
||||||
|
async def _execute(
|
||||||
|
self,
|
||||||
|
target_path: str = ".",
|
||||||
|
severity: str = "medium",
|
||||||
|
confidence: str = "medium",
|
||||||
|
max_results: int = 50,
|
||||||
|
**kwargs
|
||||||
|
) -> ToolResult:
|
||||||
|
"""执行 Bandit 扫描"""
|
||||||
|
if not await self._check_bandit():
|
||||||
|
return ToolResult(
|
||||||
|
success=False,
|
||||||
|
error="Bandit 未安装。请使用 'pip install bandit' 安装。",
|
||||||
|
)
|
||||||
|
|
||||||
|
full_path = os.path.normpath(os.path.join(self.project_root, target_path))
|
||||||
|
if not full_path.startswith(os.path.normpath(self.project_root)):
|
||||||
|
return ToolResult(success=False, error="安全错误:路径越界")
|
||||||
|
|
||||||
|
# 构建命令
# Bandit expresses thresholds by repeating the flag: -l/-ll/-lll for severity,
# -i/-ii/-iii for confidence (low/medium/high)
severity_map = {"low": "-l", "medium": "-ll", "high": "-lll"}
confidence_map = {"low": "-i", "medium": "-ii", "high": "-iii"}

cmd = [
"bandit", "-r", "-f", "json",
severity_map.get(severity, "-ll"),
confidence_map.get(confidence, "-ii"),
full_path
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
proc = await asyncio.create_subprocess_exec(
|
||||||
|
*cmd,
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.PIPE,
|
||||||
|
)
|
||||||
|
stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=120)
|
||||||
|
|
||||||
|
try:
|
||||||
|
results = json.loads(stdout.decode())
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return ToolResult(success=False, error="无法解析 Bandit 输出")
|
||||||
|
|
||||||
|
findings = results.get("results", [])[:max_results]
|
||||||
|
|
||||||
|
if not findings:
|
||||||
|
return ToolResult(
|
||||||
|
success=True,
|
||||||
|
data="Bandit 扫描完成,未发现 Python 安全问题",
|
||||||
|
metadata={"findings_count": 0}
|
||||||
|
)
|
||||||
|
|
||||||
|
output_parts = ["🐍 Bandit Python 安全扫描结果\n"]
|
||||||
|
output_parts.append(f"发现 {len(findings)} 个问题:\n")
|
||||||
|
|
||||||
|
severity_icons = {"HIGH": "🔴", "MEDIUM": "🟠", "LOW": "🟡"}
|
||||||
|
|
||||||
|
for finding in findings:
|
||||||
|
sev = finding.get("issue_severity", "LOW")
|
||||||
|
icon = severity_icons.get(sev, "⚪")
|
||||||
|
|
||||||
|
output_parts.append(f"\n{icon} [{sev}] {finding.get('test_id', '')}: {finding.get('test_name', '')}")
|
||||||
|
output_parts.append(f" 文件: {finding.get('filename', '')}:{finding.get('line_number', 0)}")
|
||||||
|
output_parts.append(f" 消息: {finding.get('issue_text', '')[:200]}")
|
||||||
|
output_parts.append(f" 代码: {finding.get('code', '')[:100]}")
|
||||||
|
|
||||||
|
return ToolResult(
|
||||||
|
success=True,
|
||||||
|
data="\n".join(output_parts),
|
||||||
|
metadata={"findings_count": len(findings), "findings": findings[:10]}
|
||||||
|
)
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
return ToolResult(success=False, error="Bandit 扫描超时")
|
||||||
|
except Exception as e:
|
||||||
|
return ToolResult(success=False, error=f"Bandit 执行错误: {str(e)}")
|
||||||
|
|
||||||
|
async def _check_bandit(self) -> bool:
|
||||||
|
try:
|
||||||
|
proc = await asyncio.create_subprocess_exec(
|
||||||
|
"bandit", "--version",
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.PIPE,
|
||||||
|
)
|
||||||
|
await proc.communicate()
|
||||||
|
return proc.returncode == 0
|
||||||
|
except Exception:
|
||||||
|
return False
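
# Note (illustrative): with the severity/confidence mapping above, the CLI call
# assembled for severity="medium", confidence="high" is roughly:
#   bandit -r -f json -ll -iii <project_root>/<target_path>
# A standalone sketch of the same flag mapping:
def _bandit_level_flags(severity: str = "medium", confidence: str = "medium") -> list:
    sev = {"low": "-l", "medium": "-ll", "high": "-lll"}
    conf = {"low": "-i", "medium": "-ii", "high": "-iii"}
    return [sev.get(severity, "-ll"), conf.get(confidence, "-ii")]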
|
||||||
|
|
||||||
|
|
||||||
|
# ============ Gitleaks 工具 ============
|
||||||
|
|
||||||
|
class GitleaksInput(BaseModel):
|
||||||
|
"""Gitleaks 扫描输入"""
|
||||||
|
target_path: str = Field(default=".", description="要扫描的目录")
|
||||||
|
no_git: bool = Field(default=True, description="不使用 git history,仅扫描文件")
|
||||||
|
max_results: int = Field(default=50, description="最大返回结果数")
|
||||||
|
|
||||||
|
|
||||||
|
class GitleaksTool(AgentTool):
|
||||||
|
"""
|
||||||
|
Gitleaks 密钥泄露检测工具
|
||||||
|
|
||||||
|
Gitleaks 是一款专门用于检测代码中硬编码密钥的工具。
|
||||||
|
可以检测:
|
||||||
|
- API Keys (AWS, GCP, Azure, GitHub, etc.)
|
||||||
|
- 私钥 (RSA, SSH, PGP)
|
||||||
|
- 数据库凭据
|
||||||
|
- OAuth tokens
|
||||||
|
- JWT secrets
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, project_root: str):
|
||||||
|
super().__init__()
|
||||||
|
self.project_root = project_root
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "gitleaks_scan"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return """使用 Gitleaks 检测代码中的密钥泄露。
|
||||||
|
Gitleaks 是专业的密钥检测工具,支持 150+ 种密钥类型。
|
||||||
|
|
||||||
|
检测类型:
|
||||||
|
- AWS Access Keys / Secret Keys
|
||||||
|
- GCP API Keys / Service Account Keys
|
||||||
|
- Azure Credentials
|
||||||
|
- GitHub / GitLab Tokens
|
||||||
|
- Private Keys (RSA, SSH, PGP)
|
||||||
|
- Database Connection Strings
|
||||||
|
- JWT Secrets
|
||||||
|
- Slack / Discord Tokens
|
||||||
|
- 等等...
|
||||||
|
|
||||||
|
建议在代码审计早期使用此工具。"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def args_schema(self):
|
||||||
|
return GitleaksInput
|
||||||
|
|
||||||
|
async def _execute(
|
||||||
|
self,
|
||||||
|
target_path: str = ".",
|
||||||
|
no_git: bool = True,
|
||||||
|
max_results: int = 50,
|
||||||
|
**kwargs
|
||||||
|
) -> ToolResult:
|
||||||
|
"""执行 Gitleaks 扫描"""
|
||||||
|
if not await self._check_gitleaks():
|
||||||
|
return ToolResult(
|
||||||
|
success=False,
|
||||||
|
error="Gitleaks 未安装。请从 https://github.com/gitleaks/gitleaks 安装。",
|
||||||
|
)
|
||||||
|
|
||||||
|
full_path = os.path.normpath(os.path.join(self.project_root, target_path))
|
||||||
|
if not full_path.startswith(os.path.normpath(self.project_root)):
|
||||||
|
return ToolResult(success=False, error="安全错误:路径越界")
|
||||||
|
|
||||||
|
cmd = ["gitleaks", "detect", "--source", full_path, "-f", "json"]
|
||||||
|
if no_git:
|
||||||
|
cmd.append("--no-git")
|
||||||
|
|
||||||
|
try:
|
||||||
|
proc = await asyncio.create_subprocess_exec(
|
||||||
|
*cmd,
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.PIPE,
|
||||||
|
)
|
||||||
|
stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=120)
|
||||||
|
|
||||||
|
# Gitleaks returns 1 if secrets found
|
||||||
|
if proc.returncode not in [0, 1]:
|
||||||
|
return ToolResult(success=False, error=f"Gitleaks 执行失败: {stderr.decode()[:300]}")
|
||||||
|
|
||||||
|
if not stdout.strip():
|
||||||
|
return ToolResult(
|
||||||
|
success=True,
|
||||||
|
data="🔐 Gitleaks 扫描完成,未发现密钥泄露",
|
||||||
|
metadata={"findings_count": 0}
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
findings = json.loads(stdout.decode())
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
findings = []
|
||||||
|
|
||||||
|
if not findings:
|
||||||
|
return ToolResult(
|
||||||
|
success=True,
|
||||||
|
data="🔐 Gitleaks 扫描完成,未发现密钥泄露",
|
||||||
|
metadata={"findings_count": 0}
|
||||||
|
)
|
||||||
|
|
||||||
|
findings = findings[:max_results]
|
||||||
|
|
||||||
|
output_parts = ["🔐 Gitleaks 密钥泄露检测结果\n"]
|
||||||
|
output_parts.append(f"⚠️ 发现 {len(findings)} 处密钥泄露!\n")
|
||||||
|
|
||||||
|
for i, finding in enumerate(findings):
|
||||||
|
output_parts.append(f"\n🔴 [{i+1}] {finding.get('RuleID', 'unknown')}")
|
||||||
|
output_parts.append(f" 描述: {finding.get('Description', '')}")
|
||||||
|
output_parts.append(f" 文件: {finding.get('File', '')}:{finding.get('StartLine', 0)}")
|
||||||
|
|
||||||
|
# 部分遮盖密钥
|
||||||
|
secret = finding.get('Secret', '')
|
||||||
|
if len(secret) > 8:
|
||||||
|
masked = secret[:4] + '*' * (len(secret) - 8) + secret[-4:]
|
||||||
|
else:
|
||||||
|
masked = '*' * len(secret)
|
||||||
|
output_parts.append(f" 密钥: {masked}")
|
||||||
|
|
||||||
|
return ToolResult(
|
||||||
|
success=True,
|
||||||
|
data="\n".join(output_parts),
|
||||||
|
metadata={
|
||||||
|
"findings_count": len(findings),
|
||||||
|
"findings": [
|
||||||
|
{"rule": f.get("RuleID"), "file": f.get("File"), "line": f.get("StartLine")}
|
||||||
|
for f in findings[:10]
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
return ToolResult(success=False, error="Gitleaks 扫描超时")
|
||||||
|
except Exception as e:
|
||||||
|
return ToolResult(success=False, error=f"Gitleaks 执行错误: {str(e)}")
|
||||||
|
|
||||||
|
async def _check_gitleaks(self) -> bool:
|
||||||
|
try:
|
||||||
|
proc = await asyncio.create_subprocess_exec(
|
||||||
|
"gitleaks", "version",
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.PIPE,
|
||||||
|
)
|
||||||
|
await proc.communicate()
|
||||||
|
return proc.returncode == 0
|
||||||
|
except Exception:
|
||||||
|
return False
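
# Helper sketch (illustrative): the masking rule used in _execute above, pulled out
# on its own - keep the first and last four characters, star out the middle.
def _mask_secret(secret: str) -> str:
    if len(secret) > 8:
        return secret[:4] + "*" * (len(secret) - 8) + secret[-4:]
    return "*" * len(secret)

# _mask_secret("ghp_abcdefghijklmnop") == "ghp_************mnop"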
|
||||||
|
|
||||||
|
|
||||||
|
# ============ npm audit 工具 ============
|
||||||
|
|
||||||
|
class NpmAuditInput(BaseModel):
|
||||||
|
"""npm audit 扫描输入"""
|
||||||
|
target_path: str = Field(default=".", description="包含 package.json 的目录")
|
||||||
|
production_only: bool = Field(default=False, description="仅扫描生产依赖")
|
||||||
|
|
||||||
|
|
||||||
|
class NpmAuditTool(AgentTool):
|
||||||
|
"""
|
||||||
|
npm audit 依赖漏洞扫描工具
|
||||||
|
|
||||||
|
扫描 Node.js 项目的依赖漏洞,基于 npm 官方漏洞数据库。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, project_root: str):
|
||||||
|
super().__init__()
|
||||||
|
self.project_root = project_root
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "npm_audit"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return """使用 npm audit 扫描 Node.js 项目的依赖漏洞。
|
||||||
|
基于 npm 官方漏洞数据库,检测已知的依赖安全问题。
|
||||||
|
|
||||||
|
适用于:
|
||||||
|
- 包含 package.json 的 Node.js 项目
|
||||||
|
- 前端项目 (React, Vue, Angular 等)
|
||||||
|
|
||||||
|
需要先运行 npm install 安装依赖。"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def args_schema(self):
|
||||||
|
return NpmAuditInput
|
||||||
|
|
||||||
|
async def _execute(
|
||||||
|
self,
|
||||||
|
target_path: str = ".",
|
||||||
|
production_only: bool = False,
|
||||||
|
**kwargs
|
||||||
|
) -> ToolResult:
|
||||||
|
"""执行 npm audit"""
|
||||||
|
full_path = os.path.normpath(os.path.join(self.project_root, target_path))
|
||||||
|
|
||||||
|
# 检查 package.json
|
||||||
|
package_json = os.path.join(full_path, "package.json")
|
||||||
|
if not os.path.exists(package_json):
|
||||||
|
return ToolResult(
|
||||||
|
success=False,
|
||||||
|
error=f"未找到 package.json: {target_path}",
|
||||||
|
)
|
||||||
|
|
||||||
|
cmd = ["npm", "audit", "--json"]
|
||||||
|
if production_only:
|
||||||
|
cmd.append("--production")
|
||||||
|
|
||||||
|
try:
|
||||||
|
proc = await asyncio.create_subprocess_exec(
|
||||||
|
*cmd,
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.PIPE,
|
||||||
|
cwd=full_path,
|
||||||
|
)
|
||||||
|
stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=120)
|
||||||
|
|
||||||
|
try:
|
||||||
|
results = json.loads(stdout.decode())
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return ToolResult(success=True, data="npm audit 输出为空或格式错误")
|
||||||
|
|
||||||
|
vulnerabilities = results.get("vulnerabilities", {})
|
||||||
|
|
||||||
|
if not vulnerabilities:
|
||||||
|
return ToolResult(
|
||||||
|
success=True,
|
||||||
|
data="📦 npm audit 完成,未发现依赖漏洞",
|
||||||
|
metadata={"findings_count": 0}
|
||||||
|
)
|
||||||
|
|
||||||
|
output_parts = ["📦 npm audit 依赖漏洞扫描结果\n"]
|
||||||
|
|
||||||
|
severity_counts = {"critical": 0, "high": 0, "moderate": 0, "low": 0}
|
||||||
|
for name, vuln in vulnerabilities.items():
|
||||||
|
severity = vuln.get("severity", "low")
|
||||||
|
severity_counts[severity] = severity_counts.get(severity, 0) + 1
|
||||||
|
|
||||||
|
output_parts.append(f"漏洞统计: 🔴 Critical: {severity_counts['critical']}, 🟠 High: {severity_counts['high']}, 🟡 Moderate: {severity_counts['moderate']}, 🟢 Low: {severity_counts['low']}\n")
|
||||||
|
|
||||||
|
severity_icons = {"critical": "🔴", "high": "🟠", "moderate": "🟡", "low": "🟢"}
|
||||||
|
|
||||||
|
for name, vuln in list(vulnerabilities.items())[:20]:
|
||||||
|
sev = vuln.get("severity", "low")
|
||||||
|
icon = severity_icons.get(sev, "⚪")
|
||||||
|
output_parts.append(f"\n{icon} [{sev.upper()}] {name}")
|
||||||
|
output_parts.append(f" 版本范围: {vuln.get('range', 'unknown')}")
|
||||||
|
|
||||||
|
via = vuln.get("via", [])
|
||||||
|
if via and isinstance(via[0], dict):
|
||||||
|
output_parts.append(f" 来源: {via[0].get('title', '')[:100]}")
|
||||||
|
|
||||||
|
return ToolResult(
|
||||||
|
success=True,
|
||||||
|
data="\n".join(output_parts),
|
||||||
|
metadata={
|
||||||
|
"findings_count": len(vulnerabilities),
|
||||||
|
"severity_counts": severity_counts,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
return ToolResult(success=False, error="npm audit 超时")
|
||||||
|
except Exception as e:
|
||||||
|
return ToolResult(success=False, error=f"npm audit 错误: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
# ============ Safety 工具 (Python 依赖) ============
|
||||||
|
|
||||||
|
class SafetyInput(BaseModel):
|
||||||
|
"""Safety 扫描输入"""
|
||||||
|
requirements_file: str = Field(default="requirements.txt", description="requirements 文件路径")
|
||||||
|
|
||||||
|
|
||||||
|
class SafetyTool(AgentTool):
|
||||||
|
"""
|
||||||
|
Safety Python 依赖漏洞扫描工具
|
||||||
|
|
||||||
|
检查 Python 依赖中的已知安全漏洞。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, project_root: str):
|
||||||
|
super().__init__()
|
||||||
|
self.project_root = project_root
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "safety_scan"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return """使用 Safety 扫描 Python 依赖的安全漏洞。
|
||||||
|
基于 PyUp.io 漏洞数据库检测已知的依赖安全问题。
|
||||||
|
|
||||||
|
适用于:
|
||||||
|
- 包含 requirements.txt 的 Python 项目
|
||||||
|
- Pipenv 项目 (Pipfile.lock)
|
||||||
|
- Poetry 项目 (poetry.lock)"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def args_schema(self):
|
||||||
|
return SafetyInput
|
||||||
|
|
||||||
|
async def _execute(
|
||||||
|
self,
|
||||||
|
requirements_file: str = "requirements.txt",
|
||||||
|
**kwargs
|
||||||
|
) -> ToolResult:
|
||||||
|
"""执行 Safety 扫描"""
|
||||||
|
full_path = os.path.join(self.project_root, requirements_file)
|
||||||
|
|
||||||
|
if not os.path.exists(full_path):
|
||||||
|
return ToolResult(success=False, error=f"未找到依赖文件: {requirements_file}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
proc = await asyncio.create_subprocess_exec(
|
||||||
|
"safety", "check", "-r", full_path, "--json",
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.PIPE,
|
||||||
|
)
|
||||||
|
stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=60)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Safety 输出的 JSON 格式可能不同版本有差异
|
||||||
|
output = stdout.decode()
|
||||||
|
if "No known security" in output:
|
||||||
|
return ToolResult(
|
||||||
|
success=True,
|
||||||
|
data="🐍 Safety 扫描完成,未发现 Python 依赖漏洞",
|
||||||
|
metadata={"findings_count": 0}
|
||||||
|
)
|
||||||
|
|
||||||
|
results = json.loads(output)
|
||||||
|
except Exception:
|
||||||
|
return ToolResult(success=True, data=f"Safety 输出:\n{stdout.decode()[:1000]}")
|
||||||
|
|
||||||
|
vulnerabilities = results if isinstance(results, list) else results.get("vulnerabilities", [])
|
||||||
|
|
||||||
|
if not vulnerabilities:
|
||||||
|
return ToolResult(
|
||||||
|
success=True,
|
||||||
|
data="🐍 Safety 扫描完成,未发现 Python 依赖漏洞",
|
||||||
|
metadata={"findings_count": 0}
|
||||||
|
)
|
||||||
|
|
||||||
|
output_parts = ["🐍 Safety Python 依赖漏洞扫描结果\n"]
|
||||||
|
output_parts.append(f"发现 {len(vulnerabilities)} 个漏洞:\n")
|
||||||
|
|
||||||
|
for vuln in vulnerabilities[:20]:
|
||||||
|
if isinstance(vuln, list) and len(vuln) >= 4:
|
||||||
|
output_parts.append(f"\n🔴 {vuln[0]} ({vuln[1]})")
|
||||||
|
output_parts.append(f" 漏洞 ID: {vuln[4] if len(vuln) > 4 else 'N/A'}")
|
||||||
|
output_parts.append(f" 描述: {vuln[3][:200] if len(vuln) > 3 else ''}")
|
||||||
|
|
||||||
|
return ToolResult(
|
||||||
|
success=True,
|
||||||
|
data="\n".join(output_parts),
|
||||||
|
metadata={"findings_count": len(vulnerabilities)}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return ToolResult(success=False, error=f"Safety 执行错误: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
# ============ TruffleHog 工具 ============
|
||||||
|
|
||||||
|
class TruffleHogInput(BaseModel):
|
||||||
|
"""TruffleHog 扫描输入"""
|
||||||
|
target_path: str = Field(default=".", description="要扫描的目录")
|
||||||
|
only_verified: bool = Field(default=False, description="仅显示已验证的密钥")
|
||||||
|
|
||||||
|
|
||||||
|
class TruffleHogTool(AgentTool):
|
||||||
|
"""
|
||||||
|
TruffleHog 深度密钥扫描工具
|
||||||
|
|
||||||
|
TruffleHog 可以检测代码和 Git 历史中的密钥泄露,
|
||||||
|
并可以验证密钥是否仍然有效。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, project_root: str):
|
||||||
|
super().__init__()
|
||||||
|
self.project_root = project_root
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "trufflehog_scan"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return """使用 TruffleHog 进行深度密钥扫描。
|
||||||
|
TruffleHog 可以扫描代码和 Git 历史,并验证密钥是否有效。
|
||||||
|
|
||||||
|
特点:
|
||||||
|
- 支持 700+ 种密钥类型
|
||||||
|
- 可以验证密钥是否仍然有效
|
||||||
|
- 扫描 Git 历史记录
|
||||||
|
- 高精度,低误报
|
||||||
|
|
||||||
|
建议与 Gitleaks 配合使用以获得最佳效果。"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def args_schema(self):
|
||||||
|
return TruffleHogInput
|
||||||
|
|
||||||
|
async def _execute(
|
||||||
|
self,
|
||||||
|
target_path: str = ".",
|
||||||
|
only_verified: bool = False,
|
||||||
|
**kwargs
|
||||||
|
) -> ToolResult:
|
||||||
|
"""执行 TruffleHog 扫描"""
|
||||||
|
full_path = os.path.normpath(os.path.join(self.project_root, target_path))
|
||||||
|
|
||||||
|
cmd = ["trufflehog", "filesystem", full_path, "--json"]
|
||||||
|
if only_verified:
|
||||||
|
cmd.append("--only-verified")
|
||||||
|
|
||||||
|
try:
|
||||||
|
proc = await asyncio.create_subprocess_exec(
|
||||||
|
*cmd,
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.PIPE,
|
||||||
|
)
|
||||||
|
stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=180)
|
||||||
|
|
||||||
|
if not stdout.strip():
|
||||||
|
return ToolResult(
|
||||||
|
success=True,
|
||||||
|
data="🔍 TruffleHog 扫描完成,未发现密钥泄露",
|
||||||
|
metadata={"findings_count": 0}
|
||||||
|
)
|
||||||
|
|
||||||
|
# TruffleHog 输出每行一个 JSON 对象
|
||||||
|
findings = []
|
||||||
|
for line in stdout.decode().strip().split('\n'):
|
||||||
|
if line.strip():
|
||||||
|
try:
|
||||||
|
findings.append(json.loads(line))
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if not findings:
|
||||||
|
return ToolResult(
|
||||||
|
success=True,
|
||||||
|
data="🔍 TruffleHog 扫描完成,未发现密钥泄露",
|
||||||
|
metadata={"findings_count": 0}
|
||||||
|
)
|
||||||
|
|
||||||
|
output_parts = ["🔍 TruffleHog 密钥扫描结果\n"]
|
||||||
|
output_parts.append(f"⚠️ 发现 {len(findings)} 处密钥泄露!\n")
|
||||||
|
|
||||||
|
for i, finding in enumerate(findings[:20]):
|
||||||
|
verified = "✅ 已验证有效" if finding.get("Verified") else "⚠️ 未验证"
|
||||||
|
output_parts.append(f"\n🔴 [{i+1}] {finding.get('DetectorName', 'unknown')} - {verified}")
|
||||||
|
output_parts.append(f" 文件: {finding.get('SourceMetadata', {}).get('Data', {}).get('Filesystem', {}).get('file', '')}")
|
||||||
|
|
||||||
|
return ToolResult(
|
||||||
|
success=True,
|
||||||
|
data="\n".join(output_parts),
|
||||||
|
metadata={"findings_count": len(findings)}
|
||||||
|
)
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
return ToolResult(success=False, error="TruffleHog 扫描超时")
|
||||||
|
except Exception as e:
|
||||||
|
return ToolResult(success=False, error=f"TruffleHog 执行错误: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
# ============ OSV-Scanner 工具 ============
|
||||||
|
|
||||||
|
class OSVScannerInput(BaseModel):
|
||||||
|
"""OSV-Scanner 扫描输入"""
|
||||||
|
target_path: str = Field(default=".", description="要扫描的项目目录")
|
||||||
|
|
||||||
|
|
||||||
|
class OSVScannerTool(AgentTool):
|
||||||
|
"""
|
||||||
|
OSV-Scanner 开源漏洞扫描工具
|
||||||
|
|
||||||
|
Google 开源的漏洞扫描工具,使用 OSV 数据库。
|
||||||
|
支持多种包管理器和锁文件。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, project_root: str):
|
||||||
|
super().__init__()
|
||||||
|
self.project_root = project_root
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "osv_scan"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return """使用 OSV-Scanner 扫描开源依赖漏洞。
|
||||||
|
Google 开源的漏洞扫描工具,使用 OSV (Open Source Vulnerabilities) 数据库。
|
||||||
|
|
||||||
|
支持:
|
||||||
|
- package.json / package-lock.json (npm)
|
||||||
|
- requirements.txt / Pipfile.lock (Python)
|
||||||
|
- go.mod / go.sum (Go)
|
||||||
|
- Cargo.lock (Rust)
|
||||||
|
- pom.xml (Maven)
|
||||||
|
- Gemfile.lock (Ruby)
|
||||||
|
- composer.lock (PHP)
|
||||||
|
|
||||||
|
特点:
|
||||||
|
- 覆盖多种语言和包管理器
|
||||||
|
- 使用 Google 维护的漏洞数据库
|
||||||
|
- 快速、准确"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def args_schema(self):
|
||||||
|
return OSVScannerInput
|
||||||
|
|
||||||
|
async def _execute(
|
||||||
|
self,
|
||||||
|
target_path: str = ".",
|
||||||
|
**kwargs
|
||||||
|
) -> ToolResult:
|
||||||
|
"""执行 OSV-Scanner"""
|
||||||
|
full_path = os.path.normpath(os.path.join(self.project_root, target_path))
|
||||||
|
|
||||||
|
try:
|
||||||
|
proc = await asyncio.create_subprocess_exec(
|
||||||
|
"osv-scanner", "--json", "-r", full_path,
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.PIPE,
|
||||||
|
)
|
||||||
|
stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=120)
|
||||||
|
|
||||||
|
try:
|
||||||
|
results = json.loads(stdout.decode())
|
||||||
|
except Exception:
|
||||||
|
if "no package sources found" in stdout.decode().lower():
|
||||||
|
return ToolResult(success=True, data="OSV-Scanner: 未找到可扫描的包文件")
|
||||||
|
return ToolResult(success=True, data=f"OSV-Scanner 输出:\n{stdout.decode()[:1000]}")
|
||||||
|
|
||||||
|
vulns = results.get("results", [])
|
||||||
|
|
||||||
|
if not vulns:
|
||||||
|
return ToolResult(
|
||||||
|
success=True,
|
||||||
|
data="📋 OSV-Scanner 扫描完成,未发现依赖漏洞",
|
||||||
|
metadata={"findings_count": 0}
|
||||||
|
)
|
||||||
|
|
||||||
|
total_vulns = sum(len(r.get("vulnerabilities", [])) for r in vulns)
|
||||||
|
|
||||||
|
output_parts = ["📋 OSV-Scanner 开源漏洞扫描结果\n"]
|
||||||
|
output_parts.append(f"发现 {total_vulns} 个漏洞:\n")
|
||||||
|
|
||||||
|
for result in vulns[:10]:
|
||||||
|
source = result.get("source", {}).get("path", "unknown")
|
||||||
|
for vuln in result.get("vulnerabilities", [])[:5]:
|
||||||
|
vuln_id = vuln.get("id", "")
|
||||||
|
summary = vuln.get("summary", "")[:100]
|
||||||
|
output_parts.append(f"\n🔴 {vuln_id}")
|
||||||
|
output_parts.append(f" 来源: {source}")
|
||||||
|
output_parts.append(f" 描述: {summary}")
|
||||||
|
|
||||||
|
return ToolResult(
|
||||||
|
success=True,
|
||||||
|
data="\n".join(output_parts),
|
||||||
|
metadata={"findings_count": total_vulns}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return ToolResult(success=False, error=f"OSV-Scanner 执行错误: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
# ============ 导出所有工具 ============
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"SemgrepTool",
|
||||||
|
"BanditTool",
|
||||||
|
"GitleaksTool",
|
||||||
|
"NpmAuditTool",
|
||||||
|
"SafetyTool",
|
||||||
|
"TruffleHogTool",
|
||||||
|
"OSVScannerTool",
|
||||||
|
]
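
# Assembly sketch (illustrative): one way the agent layer might instantiate the
# exported tools for a project; the registry that consumes this list is assumed.
def _build_external_tools(project_root: str) -> list:
    return [
        SemgrepTool(project_root),
        BanditTool(project_root),
        GitleaksTool(project_root),
        NpmAuditTool(project_root),
        SafetyTool(project_root),
        TruffleHogTool(project_root),
        OSVScannerTool(project_root),
    ]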
|
||||||
|
|
||||||
|
|
@ -0,0 +1,481 @@
"""
文件操作工具
读取和搜索代码文件
"""

import os
import re
import fnmatch
from typing import Optional, List, Dict, Any
from pydantic import BaseModel, Field

from .base import AgentTool, ToolResult

class FileReadInput(BaseModel):
|
||||||
|
"""文件读取输入"""
|
||||||
|
file_path: str = Field(description="文件路径(相对于项目根目录)")
|
||||||
|
start_line: Optional[int] = Field(default=None, description="起始行号(从1开始)")
|
||||||
|
end_line: Optional[int] = Field(default=None, description="结束行号")
|
||||||
|
max_lines: int = Field(default=500, description="最大返回行数")
|
||||||
|
|
||||||
|
|
||||||
|
class FileReadTool(AgentTool):
|
||||||
|
"""
|
||||||
|
文件读取工具
|
||||||
|
读取项目中的文件内容
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, project_root: str):
|
||||||
|
"""
|
||||||
|
初始化文件读取工具
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_root: 项目根目录
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.project_root = project_root
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "read_file"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return """读取项目中的文件内容。
|
||||||
|
|
||||||
|
使用场景:
|
||||||
|
- 查看完整的源代码文件
|
||||||
|
- 查看特定行范围的代码
|
||||||
|
- 获取配置文件内容
|
||||||
|
|
||||||
|
输入:
|
||||||
|
- file_path: 文件路径(相对于项目根目录)
|
||||||
|
- start_line: 可选,起始行号
|
||||||
|
- end_line: 可选,结束行号
|
||||||
|
- max_lines: 最大返回行数(默认500)
|
||||||
|
|
||||||
|
注意: 为避免输出过长,建议指定行范围或使用 RAG 搜索定位代码。"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def args_schema(self):
|
||||||
|
return FileReadInput
|
||||||
|
|
||||||
|
async def _execute(
|
||||||
|
self,
|
||||||
|
file_path: str,
|
||||||
|
start_line: Optional[int] = None,
|
||||||
|
end_line: Optional[int] = None,
|
||||||
|
max_lines: int = 500,
|
||||||
|
**kwargs
|
||||||
|
) -> ToolResult:
|
||||||
|
"""执行文件读取"""
|
||||||
|
try:
|
||||||
|
# 安全检查:防止路径遍历
|
||||||
|
full_path = os.path.normpath(os.path.join(self.project_root, file_path))
|
||||||
|
if not full_path.startswith(os.path.normpath(self.project_root)):
|
||||||
|
return ToolResult(
|
||||||
|
success=False,
|
||||||
|
error="安全错误:不允许访问项目目录外的文件",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not os.path.exists(full_path):
|
||||||
|
return ToolResult(
|
||||||
|
success=False,
|
||||||
|
error=f"文件不存在: {file_path}",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not os.path.isfile(full_path):
|
||||||
|
return ToolResult(
|
||||||
|
success=False,
|
||||||
|
error=f"不是文件: {file_path}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 检查文件大小
|
||||||
|
file_size = os.path.getsize(full_path)
|
||||||
|
if file_size > 1024 * 1024: # 1MB
|
||||||
|
return ToolResult(
|
||||||
|
success=False,
|
||||||
|
error=f"文件过大 ({file_size / 1024:.1f}KB),请指定行范围",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 读取文件
|
||||||
|
with open(full_path, 'r', encoding='utf-8', errors='ignore') as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
|
||||||
|
total_lines = len(lines)
|
||||||
|
|
||||||
|
# 处理行范围
|
||||||
|
if start_line is not None:
|
||||||
|
start_idx = max(0, start_line - 1)
|
||||||
|
else:
|
||||||
|
start_idx = 0
|
||||||
|
|
||||||
|
if end_line is not None:
|
||||||
|
end_idx = min(total_lines, end_line)
|
||||||
|
else:
|
||||||
|
end_idx = min(total_lines, start_idx + max_lines)
|
||||||
|
|
||||||
|
# 截取指定行
|
||||||
|
selected_lines = lines[start_idx:end_idx]
|
||||||
|
|
||||||
|
# 添加行号
|
||||||
|
numbered_lines = []
|
||||||
|
for i, line in enumerate(selected_lines, start=start_idx + 1):
|
||||||
|
numbered_lines.append(f"{i:4d}| {line.rstrip()}")
|
||||||
|
|
||||||
|
content = '\n'.join(numbered_lines)
|
||||||
|
|
||||||
|
# 检测语言
|
||||||
|
ext = os.path.splitext(file_path)[1].lower()
|
||||||
|
language = {
|
||||||
|
".py": "python", ".js": "javascript", ".ts": "typescript",
|
||||||
|
".java": "java", ".go": "go", ".rs": "rust",
|
||||||
|
".cpp": "cpp", ".c": "c", ".cs": "csharp",
|
||||||
|
".php": "php", ".rb": "ruby", ".swift": "swift",
|
||||||
|
}.get(ext, "text")
|
||||||
|
|
||||||
|
output = f"📄 文件: {file_path}\n"
|
||||||
|
output += f"行数: {start_idx + 1}-{end_idx} / {total_lines}\n\n"
|
||||||
|
output += f"```{language}\n{content}\n```"
|
||||||
|
|
||||||
|
if end_idx < total_lines:
|
||||||
|
output += f"\n\n... 还有 {total_lines - end_idx} 行未显示"
|
||||||
|
|
||||||
|
return ToolResult(
|
||||||
|
success=True,
|
||||||
|
data=output,
|
||||||
|
metadata={
|
||||||
|
"file_path": file_path,
|
||||||
|
"total_lines": total_lines,
|
||||||
|
"start_line": start_idx + 1,
|
||||||
|
"end_line": end_idx,
|
||||||
|
"language": language,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return ToolResult(
|
||||||
|
success=False,
|
||||||
|
error=f"读取文件失败: {str(e)}",
|
||||||
|
)
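
# Output sketch (illustrative): for a hypothetical call
# read_file(file_path="app/main.py", start_line=10, end_line=12) on a 120-line file,
# the returned data is shaped like:
#
#   📄 文件: app/main.py
#   行数: 10-12 / 120
#
#   ```python
#     10| def create_app():
#     11|     app = FastAPI()
#     12|     return app
#   ```
#
#   ... 还有 108 行未显示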
|
||||||
|
|
||||||
|
|
||||||
|
class FileSearchInput(BaseModel):
|
||||||
|
"""文件搜索输入"""
|
||||||
|
keyword: str = Field(description="搜索关键字或正则表达式")
|
||||||
|
file_pattern: Optional[str] = Field(default=None, description="文件名模式,如 *.py, *.js")
|
||||||
|
directory: Optional[str] = Field(default=None, description="搜索目录(相对路径)")
|
||||||
|
case_sensitive: bool = Field(default=False, description="是否区分大小写")
|
||||||
|
max_results: int = Field(default=50, description="最大结果数")
|
||||||
|
is_regex: bool = Field(default=False, description="是否使用正则表达式")
|
||||||
|
|
||||||
|
|
||||||
|
class FileSearchTool(AgentTool):
|
||||||
|
"""
|
||||||
|
文件搜索工具
|
||||||
|
在项目中搜索包含特定内容的代码
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 排除的目录
|
||||||
|
EXCLUDE_DIRS = {
|
||||||
|
"node_modules", "vendor", "dist", "build", ".git",
|
||||||
|
"__pycache__", ".pytest_cache", "coverage", ".nyc_output",
|
||||||
|
".vscode", ".idea", ".vs", "target", "venv", "env",
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, project_root: str):
|
||||||
|
super().__init__()
|
||||||
|
self.project_root = project_root
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "search_code"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return """在项目代码中搜索关键字或模式。
|
||||||
|
|
||||||
|
使用场景:
|
||||||
|
- 查找特定函数的所有调用位置
|
||||||
|
- 搜索特定的 API 使用
|
||||||
|
- 查找包含特定模式的代码
|
||||||
|
|
||||||
|
输入:
|
||||||
|
- keyword: 搜索关键字或正则表达式
|
||||||
|
- file_pattern: 可选,文件名模式(如 *.py)
|
||||||
|
- directory: 可选,搜索目录
|
||||||
|
- case_sensitive: 是否区分大小写
|
||||||
|
- is_regex: 是否使用正则表达式
|
||||||
|
|
||||||
|
这是一个快速搜索工具,结果包含匹配行和上下文。"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def args_schema(self):
|
||||||
|
return FileSearchInput
|
||||||
|
|
||||||
|
async def _execute(
|
||||||
|
self,
|
||||||
|
keyword: str,
|
||||||
|
file_pattern: Optional[str] = None,
|
||||||
|
directory: Optional[str] = None,
|
||||||
|
case_sensitive: bool = False,
|
||||||
|
max_results: int = 50,
|
||||||
|
is_regex: bool = False,
|
||||||
|
**kwargs
|
||||||
|
) -> ToolResult:
|
||||||
|
"""执行文件搜索"""
|
||||||
|
try:
|
||||||
|
# 确定搜索目录
|
||||||
|
if directory:
|
||||||
|
search_dir = os.path.normpath(os.path.join(self.project_root, directory))
|
||||||
|
if not search_dir.startswith(os.path.normpath(self.project_root)):
|
||||||
|
return ToolResult(
|
||||||
|
success=False,
|
||||||
|
error="安全错误:不允许搜索项目目录外的内容",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
search_dir = self.project_root
|
||||||
|
|
||||||
|
# 编译搜索模式
|
||||||
|
flags = 0 if case_sensitive else re.IGNORECASE
|
||||||
|
try:
|
||||||
|
if is_regex:
|
||||||
|
pattern = re.compile(keyword, flags)
|
||||||
|
else:
|
||||||
|
pattern = re.compile(re.escape(keyword), flags)
|
||||||
|
except re.error as e:
|
||||||
|
return ToolResult(
|
||||||
|
success=False,
|
||||||
|
error=f"无效的搜索模式: {e}",
|
||||||
|
)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
files_searched = 0
|
||||||
|
|
||||||
|
# 遍历文件
|
||||||
|
for root, dirs, files in os.walk(search_dir):
|
||||||
|
# 排除目录
|
||||||
|
dirs[:] = [d for d in dirs if d not in self.EXCLUDE_DIRS]
|
||||||
|
|
||||||
|
for filename in files:
|
||||||
|
# 检查文件模式
|
||||||
|
if file_pattern and not fnmatch.fnmatch(filename, file_pattern):
|
||||||
|
continue
|
||||||
|
|
||||||
|
file_path = os.path.join(root, filename)
|
||||||
|
relative_path = os.path.relpath(file_path, self.project_root)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
|
||||||
|
files_searched += 1
|
||||||
|
|
||||||
|
for i, line in enumerate(lines):
|
||||||
|
if pattern.search(line):
|
||||||
|
# 获取上下文
|
||||||
|
start = max(0, i - 1)
|
||||||
|
end = min(len(lines), i + 2)
|
||||||
|
context_lines = []
|
||||||
|
for j in range(start, end):
|
||||||
|
prefix = ">" if j == i else " "
|
||||||
|
context_lines.append(f"{prefix} {j+1:4d}| {lines[j].rstrip()}")
|
||||||
|
|
||||||
|
results.append({
|
||||||
|
"file": relative_path,
|
||||||
|
"line": i + 1,
|
||||||
|
"match": line.strip()[:200],
|
||||||
|
"context": '\n'.join(context_lines),
|
||||||
|
})
|
||||||
|
|
||||||
|
if len(results) >= max_results:
|
||||||
|
break
|
||||||
|
|
||||||
|
if len(results) >= max_results:
|
||||||
|
break
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if len(results) >= max_results:
|
||||||
|
break
|
||||||
|
|
||||||
|
if not results:
|
||||||
|
return ToolResult(
|
||||||
|
success=True,
|
||||||
|
data=f"没有找到匹配 '{keyword}' 的内容\n搜索了 {files_searched} 个文件",
|
||||||
|
metadata={"files_searched": files_searched, "matches": 0}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 格式化输出
|
||||||
|
output_parts = [f"🔍 搜索结果: '{keyword}'\n"]
|
||||||
|
output_parts.append(f"找到 {len(results)} 处匹配(搜索了 {files_searched} 个文件)\n")
|
||||||
|
|
||||||
|
for result in results:
|
||||||
|
output_parts.append(f"\n📄 {result['file']}:{result['line']}")
|
||||||
|
output_parts.append(f"```\n{result['context']}\n```")
|
||||||
|
|
||||||
|
if len(results) >= max_results:
|
||||||
|
output_parts.append(f"\n... 结果已截断(最大 {max_results} 条)")
|
||||||
|
|
||||||
|
return ToolResult(
|
||||||
|
success=True,
|
||||||
|
data="\n".join(output_parts),
|
||||||
|
metadata={
|
||||||
|
"keyword": keyword,
|
||||||
|
"files_searched": files_searched,
|
||||||
|
"matches": len(results),
|
||||||
|
"results": results[:10], # 只在元数据中保留前10个
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return ToolResult(
|
||||||
|
success=False,
|
||||||
|
error=f"搜索失败: {str(e)}",
|
||||||
|
)
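
# Sketch (illustrative): the keyword is escaped unless is_regex=True, so literal
# strings such as "eval(" can be searched without writing a regex.
def _compile_keyword(keyword: str, is_regex: bool = False, case_sensitive: bool = False):
    flags = 0 if case_sensitive else re.IGNORECASE
    return re.compile(keyword if is_regex else re.escape(keyword), flags)

# _compile_keyword("eval(").search("result = eval(user_input)") is not None  # True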
|
||||||
|
|
||||||
|
|
||||||
|
class ListFilesInput(BaseModel):
|
||||||
|
"""列出文件输入"""
|
||||||
|
directory: str = Field(default=".", description="目录路径(相对于项目根目录)")
|
||||||
|
pattern: Optional[str] = Field(default=None, description="文件名模式,如 *.py")
|
||||||
|
recursive: bool = Field(default=False, description="是否递归列出子目录")
|
||||||
|
max_files: int = Field(default=100, description="最大文件数")
|
||||||
|
|
||||||
|
|
||||||
|
class ListFilesTool(AgentTool):
|
||||||
|
"""
|
||||||
|
列出文件工具
|
||||||
|
列出目录中的文件
|
||||||
|
"""
|
||||||
|
|
||||||
|
EXCLUDE_DIRS = {
|
||||||
|
"node_modules", "vendor", "dist", "build", ".git",
|
||||||
|
"__pycache__", ".pytest_cache", "coverage",
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, project_root: str):
|
||||||
|
super().__init__()
|
||||||
|
self.project_root = project_root
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "list_files"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return """列出目录中的文件。
|
||||||
|
|
||||||
|
使用场景:
|
||||||
|
- 了解项目结构
|
||||||
|
- 查找特定类型的文件
|
||||||
|
- 浏览目录内容
|
||||||
|
|
||||||
|
输入:
|
||||||
|
- directory: 目录路径
|
||||||
|
- pattern: 可选,文件名模式
|
||||||
|
- recursive: 是否递归
|
||||||
|
- max_files: 最大文件数"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def args_schema(self):
|
||||||
|
return ListFilesInput
|
||||||
|
|
||||||
|
async def _execute(
|
||||||
|
self,
|
||||||
|
directory: str = ".",
|
||||||
|
pattern: Optional[str] = None,
|
||||||
|
recursive: bool = False,
|
||||||
|
max_files: int = 100,
|
||||||
|
**kwargs
|
||||||
|
) -> ToolResult:
|
||||||
|
"""执行文件列表"""
|
||||||
|
try:
|
||||||
|
target_dir = os.path.normpath(os.path.join(self.project_root, directory))
|
||||||
|
if not target_dir.startswith(os.path.normpath(self.project_root)):
|
||||||
|
return ToolResult(
|
||||||
|
success=False,
|
||||||
|
error="安全错误:不允许访问项目目录外的目录",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not os.path.exists(target_dir):
|
||||||
|
return ToolResult(
|
||||||
|
success=False,
|
||||||
|
error=f"目录不存在: {directory}",
|
||||||
|
)
|
||||||
|
|
||||||
|
files = []
|
||||||
|
dirs = []
|
||||||
|
|
||||||
|
if recursive:
|
||||||
|
for root, dirnames, filenames in os.walk(target_dir):
|
||||||
|
# 排除目录
|
||||||
|
dirnames[:] = [d for d in dirnames if d not in self.EXCLUDE_DIRS]
|
||||||
|
|
||||||
|
for filename in filenames:
|
||||||
|
if pattern and not fnmatch.fnmatch(filename, pattern):
|
||||||
|
continue
|
||||||
|
|
||||||
|
full_path = os.path.join(root, filename)
|
||||||
|
relative_path = os.path.relpath(full_path, self.project_root)
|
||||||
|
files.append(relative_path)
|
||||||
|
|
||||||
|
if len(files) >= max_files:
|
||||||
|
break
|
||||||
|
|
||||||
|
if len(files) >= max_files:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
for item in os.listdir(target_dir):
|
||||||
|
if item in self.EXCLUDE_DIRS:
|
||||||
|
continue
|
||||||
|
|
||||||
|
full_path = os.path.join(target_dir, item)
|
||||||
|
relative_path = os.path.relpath(full_path, self.project_root)
|
||||||
|
|
||||||
|
if os.path.isdir(full_path):
|
||||||
|
dirs.append(relative_path + "/")
|
||||||
|
else:
|
||||||
|
if pattern and not fnmatch.fnmatch(item, pattern):
|
||||||
|
continue
|
||||||
|
files.append(relative_path)
|
||||||
|
|
||||||
|
if len(files) >= max_files:
|
||||||
|
break
|
||||||
|
|
||||||
|
# 格式化输出
|
||||||
|
output_parts = [f"📁 目录: {directory}\n"]
|
||||||
|
|
||||||
|
if dirs:
|
||||||
|
output_parts.append("目录:")
|
||||||
|
for d in sorted(dirs)[:20]:
|
||||||
|
output_parts.append(f" 📂 {d}")
|
||||||
|
if len(dirs) > 20:
|
||||||
|
output_parts.append(f" ... 还有 {len(dirs) - 20} 个目录")
|
||||||
|
|
||||||
|
if files:
|
||||||
|
output_parts.append(f"\n文件 ({len(files)}):")
|
||||||
|
for f in sorted(files):
|
||||||
|
output_parts.append(f" 📄 {f}")
|
||||||
|
|
||||||
|
if len(files) >= max_files:
|
||||||
|
output_parts.append(f"\n... 结果已截断(最大 {max_files} 个文件)")
|
||||||
|
|
||||||
|
return ToolResult(
|
||||||
|
success=True,
|
||||||
|
data="\n".join(output_parts),
|
||||||
|
metadata={
|
||||||
|
"directory": directory,
|
||||||
|
"file_count": len(files),
|
||||||
|
"dir_count": len(dirs),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return ToolResult(
|
||||||
|
success=False,
|
||||||
|
error=f"列出文件失败: {str(e)}",
|
||||||
|
)
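
# Note (illustrative): pattern is a shell-style glob handled by fnmatch, e.g.
#   fnmatch.fnmatch("routes.py", "*.py")   -> True
#   fnmatch.fnmatch("routes.pyc", "*.py")  -> False
# so a call like list_files(directory="app", pattern="*.py", recursive=True) returns
# only Python sources under app/ while skipping EXCLUDE_DIRS.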
|
||||||
|
|
||||||
|
|
@ -0,0 +1,418 @@
"""
模式匹配工具
快速扫描代码中的危险模式
"""

import re
from typing import Optional, List, Dict, Any
from pydantic import BaseModel, Field
from dataclasses import dataclass

from .base import AgentTool, ToolResult

@dataclass
|
||||||
|
class PatternMatch:
|
||||||
|
"""模式匹配结果"""
|
||||||
|
pattern_name: str
|
||||||
|
pattern_type: str
|
||||||
|
file_path: str
|
||||||
|
line_number: int
|
||||||
|
matched_text: str
|
||||||
|
context: str
|
||||||
|
severity: str
|
||||||
|
description: str
|
||||||
|
|
||||||
|
|
||||||
|
class PatternMatchInput(BaseModel):
|
||||||
|
"""模式匹配输入"""
|
||||||
|
code: str = Field(description="要扫描的代码内容")
|
||||||
|
file_path: str = Field(default="unknown", description="文件路径")
|
||||||
|
pattern_types: Optional[List[str]] = Field(
|
||||||
|
default=None,
|
||||||
|
description="要检测的漏洞类型列表,如 ['sql_injection', 'xss']。为空则检测所有类型"
|
||||||
|
)
|
||||||
|
language: Optional[str] = Field(default=None, description="编程语言,用于选择特定模式")
|
||||||
|
|
||||||
|
|
||||||
|
class PatternMatchTool(AgentTool):
|
||||||
|
"""
|
||||||
|
模式匹配工具
|
||||||
|
使用正则表达式快速扫描代码中的危险模式
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 危险模式定义
|
||||||
|
PATTERNS: Dict[str, Dict[str, Any]] = {
|
||||||
|
# SQL 注入模式
|
||||||
|
"sql_injection": {
|
||||||
|
"patterns": {
|
||||||
|
"python": [
|
||||||
|
(r'cursor\.execute\s*\(\s*["\'].*%[sd].*["\'].*%', "格式化字符串构造SQL"),
|
||||||
|
(r'cursor\.execute\s*\(\s*f["\']', "f-string构造SQL"),
|
||||||
|
(r'cursor\.execute\s*\([^,)]+\+', "字符串拼接构造SQL"),
|
||||||
|
(r'\.execute\s*\(\s*["\'][^"\']*\{', "format()构造SQL"),
|
||||||
|
(r'text\s*\(\s*["\'].*\+.*["\']', "SQLAlchemy text()拼接"),
|
||||||
|
],
|
||||||
|
"javascript": [
|
||||||
|
(r'\.query\s*\(\s*[`"\'].*\$\{', "模板字符串构造SQL"),
|
||||||
|
(r'\.query\s*\(\s*["\'].*\+', "字符串拼接构造SQL"),
|
||||||
|
(r'mysql\.query\s*\([^,)]+\+', "MySQL查询拼接"),
|
||||||
|
],
|
||||||
|
"java": [
|
||||||
|
(r'Statement.*execute.*\+', "Statement字符串拼接"),
|
||||||
|
(r'createQuery\s*\([^,)]+\+', "JPA查询拼接"),
|
||||||
|
(r'\.executeQuery\s*\([^,)]+\+', "executeQuery拼接"),
|
||||||
|
],
|
||||||
|
"php": [
|
||||||
|
(r'mysql_query\s*\(\s*["\'].*\.\s*\$', "mysql_query拼接"),
|
||||||
|
(r'mysqli_query\s*\([^,]+,\s*["\'].*\.\s*\$', "mysqli_query拼接"),
|
||||||
|
(r'\$pdo->query\s*\(\s*["\'].*\.\s*\$', "PDO query拼接"),
|
||||||
|
],
|
||||||
|
"go": [
|
||||||
|
(r'\.Query\s*\([^,)]+\+', "Query字符串拼接"),
|
||||||
|
(r'\.Exec\s*\([^,)]+\+', "Exec字符串拼接"),
|
||||||
|
(r'fmt\.Sprintf\s*\([^)]+\)\s*\)', "Sprintf构造SQL"),
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"severity": "high",
|
||||||
|
"description": "SQL注入漏洞:用户输入直接拼接到SQL语句中",
|
||||||
|
},
|
||||||
|
|
||||||
|
# XSS 模式
|
||||||
|
"xss": {
|
||||||
|
"patterns": {
|
||||||
|
"javascript": [
|
||||||
|
(r'innerHTML\s*=\s*[^;]+', "innerHTML赋值"),
|
||||||
|
(r'outerHTML\s*=\s*[^;]+', "outerHTML赋值"),
|
||||||
|
(r'document\.write\s*\(', "document.write"),
|
||||||
|
(r'\.html\s*\([^)]+\)', "jQuery html()"),
|
||||||
|
(r'dangerouslySetInnerHTML', "React dangerouslySetInnerHTML"),
|
||||||
|
],
|
||||||
|
"python": [
|
||||||
|
(r'\|\s*safe\b', "Django safe过滤器"),
|
||||||
|
(r'Markup\s*\(', "Flask Markup"),
|
||||||
|
(r'mark_safe\s*\(', "Django mark_safe"),
|
||||||
|
],
|
||||||
|
"php": [
|
||||||
|
(r'echo\s+\$_(?:GET|POST|REQUEST)', "直接输出用户输入"),
|
||||||
|
(r'print\s+\$_(?:GET|POST|REQUEST)', "打印用户输入"),
|
||||||
|
],
|
||||||
|
"java": [
|
||||||
|
(r'out\.print(?:ln)?\s*\([^)]*request\.getParameter', "直接输出请求参数"),
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"severity": "high",
|
||||||
|
"description": "XSS跨站脚本漏洞:未转义的用户输入被渲染到页面",
|
||||||
|
},
|
||||||
|
|
||||||
|
# 命令注入模式
|
||||||
|
"command_injection": {
|
||||||
|
"patterns": {
|
||||||
|
"python": [
|
||||||
|
(r'os\.system\s*\([^)]*\+', "os.system拼接"),
|
||||||
|
(r'os\.system\s*\([^)]*%', "os.system格式化"),
|
||||||
|
(r'os\.system\s*\(\s*f["\']', "os.system f-string"),
|
||||||
|
(r'subprocess\.(?:call|run|Popen)\s*\([^)]*shell\s*=\s*True', "shell=True"),
|
||||||
|
(r'subprocess\.(?:call|run|Popen)\s*\(\s*["\'][^"\']+%', "subprocess格式化"),
|
||||||
|
(r'eval\s*\(', "eval()"),
|
||||||
|
(r'exec\s*\(', "exec()"),
|
||||||
|
],
|
||||||
|
"javascript": [
|
||||||
|
(r'exec\s*\([^)]+\+', "exec拼接"),
|
||||||
|
(r'spawn\s*\([^)]+,\s*\{[^}]*shell:\s*true', "spawn shell"),
|
||||||
|
(r'eval\s*\(', "eval()"),
|
||||||
|
(r'Function\s*\(', "Function构造器"),
|
||||||
|
],
|
||||||
|
"php": [
|
||||||
|
(r'exec\s*\(\s*\$', "exec变量"),
|
||||||
|
(r'system\s*\(\s*\$', "system变量"),
|
||||||
|
(r'passthru\s*\(\s*\$', "passthru变量"),
|
||||||
|
(r'shell_exec\s*\(\s*\$', "shell_exec变量"),
|
||||||
|
(r'`[^`]*\$[^`]*`', "反引号命令执行"),
|
||||||
|
],
|
||||||
|
"java": [
|
||||||
|
(r'Runtime\.getRuntime\(\)\.exec\s*\([^)]+\+', "Runtime.exec拼接"),
|
||||||
|
(r'ProcessBuilder[^;]+\+', "ProcessBuilder拼接"),
|
||||||
|
],
|
||||||
|
"go": [
|
||||||
|
(r'exec\.Command\s*\([^)]+\+', "exec.Command拼接"),
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"severity": "critical",
|
||||||
|
"description": "命令注入漏洞:用户输入被用于执行系统命令",
|
||||||
|
},
|
||||||
|
|
||||||
|
# 路径遍历模式
|
||||||
|
"path_traversal": {
|
||||||
|
"patterns": {
|
||||||
|
"python": [
|
||||||
|
(r'open\s*\([^)]*\+', "open()拼接"),
|
||||||
|
(r'open\s*\([^)]*%', "open()格式化"),
|
||||||
|
(r'os\.path\.join\s*\([^)]*request', "join用户输入"),
|
||||||
|
(r'send_file\s*\([^)]*request', "send_file用户输入"),
|
||||||
|
],
|
||||||
|
"javascript": [
|
||||||
|
(r'fs\.read(?:File|FileSync)\s*\([^)]+\+', "readFile拼接"),
|
||||||
|
(r'path\.join\s*\([^)]*req\.', "path.join用户输入"),
|
||||||
|
(r'res\.sendFile\s*\([^)]+\+', "sendFile拼接"),
|
||||||
|
],
|
||||||
|
"php": [
|
||||||
|
(r'include\s*\(\s*\$', "include变量"),
|
||||||
|
(r'require\s*\(\s*\$', "require变量"),
|
||||||
|
(r'file_get_contents\s*\(\s*\$', "file_get_contents变量"),
|
||||||
|
(r'fopen\s*\(\s*\$', "fopen变量"),
|
||||||
|
],
|
||||||
|
"java": [
|
||||||
|
(r'new\s+File\s*\([^)]+request\.getParameter', "File构造用户输入"),
|
||||||
|
(r'new\s+FileInputStream\s*\([^)]+\+', "FileInputStream拼接"),
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"severity": "high",
|
||||||
|
"description": "路径遍历漏洞:用户可以访问任意文件",
|
||||||
|
},
|
||||||
|
|
||||||
|
# SSRF 模式
|
||||||
|
"ssrf": {
|
||||||
|
"patterns": {
|
||||||
|
"python": [
|
||||||
|
(r'requests\.(?:get|post|put|delete)\s*\([^)]*request\.', "requests用户URL"),
|
||||||
|
(r'urllib\.request\.urlopen\s*\([^)]*request\.', "urlopen用户URL"),
|
||||||
|
(r'httpx\.(?:get|post)\s*\([^)]*request\.', "httpx用户URL"),
|
||||||
|
],
|
||||||
|
"javascript": [
|
||||||
|
(r'fetch\s*\([^)]*req\.', "fetch用户URL"),
|
||||||
|
(r'axios\.(?:get|post)\s*\([^)]*req\.', "axios用户URL"),
|
||||||
|
(r'http\.request\s*\([^)]*req\.', "http.request用户URL"),
|
||||||
|
],
|
||||||
|
"java": [
|
||||||
|
(r'new\s+URL\s*\([^)]*request\.getParameter', "URL构造用户输入"),
|
||||||
|
(r'HttpClient[^;]+request\.getParameter', "HttpClient用户URL"),
|
||||||
|
],
|
||||||
|
"php": [
|
||||||
|
(r'curl_setopt[^;]+CURLOPT_URL[^;]+\$', "curl用户URL"),
|
||||||
|
(r'file_get_contents\s*\(\s*\$_', "file_get_contents用户URL"),
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"severity": "high",
|
||||||
|
"description": "SSRF漏洞:服务端请求用户控制的URL",
|
||||||
|
},
|
||||||
|
|
||||||
|
# 不安全的反序列化
|
||||||
|
"deserialization": {
|
||||||
|
"patterns": {
|
||||||
|
"python": [
|
||||||
|
(r'pickle\.loads?\s*\(', "pickle反序列化"),
|
||||||
|
(r'yaml\.load\s*\((?![^)]*Loader)', "yaml.load无安全Loader"),
|
||||||
|
(r'yaml\.unsafe_load\s*\(', "yaml.unsafe_load"),
|
||||||
|
(r'marshal\.loads?\s*\(', "marshal反序列化"),
|
||||||
|
],
|
||||||
|
"javascript": [
|
||||||
|
(r'serialize\s*\(', "serialize"),
|
||||||
|
(r'unserialize\s*\(', "unserialize"),
|
||||||
|
],
|
||||||
|
"java": [
|
||||||
|
(r'ObjectInputStream\s*\(', "ObjectInputStream"),
|
||||||
|
(r'XMLDecoder\s*\(', "XMLDecoder"),
|
||||||
|
(r'readObject\s*\(', "readObject"),
|
||||||
|
],
|
||||||
|
"php": [
|
||||||
|
(r'unserialize\s*\(\s*\$', "unserialize用户输入"),
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"severity": "critical",
|
||||||
|
"description": "不安全的反序列化:可能导致远程代码执行",
|
||||||
|
},
|
||||||
|
|
||||||
|
# 硬编码密钥
|
||||||
|
"hardcoded_secret": {
|
||||||
|
"patterns": {
|
||||||
|
"_common": [
|
||||||
|
(r'(?:password|passwd|pwd)\s*=\s*["\'][^"\']{4,}["\']', "硬编码密码"),
|
||||||
|
(r'(?:secret|api_?key|apikey|token|auth)\s*=\s*["\'][^"\']{8,}["\']', "硬编码密钥"),
|
||||||
|
(r'(?:private_?key|priv_?key)\s*=\s*["\'][^"\']+["\']', "硬编码私钥"),
|
||||||
|
(r'-----BEGIN\s+(?:RSA\s+)?PRIVATE\s+KEY-----', "私钥"),
|
||||||
|
(r'(?:aws_?access_?key|aws_?secret)\s*=\s*["\'][^"\']+["\']', "AWS密钥"),
|
||||||
|
(r'(?:ghp_|gho_|github_pat_)[a-zA-Z0-9]{36,}', "GitHub Token"),
|
||||||
|
(r'sk-[a-zA-Z0-9]{48}', "OpenAI API Key"),
|
||||||
|
(r'(?:bearer|authorization)\s*[=:]\s*["\'][^"\']{20,}["\']', "Bearer Token"),
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"severity": "medium",
|
||||||
|
"description": "硬编码密钥:敏感信息不应该硬编码在代码中",
|
||||||
|
},
|
||||||
|
|
||||||
|
# 弱加密
|
||||||
|
"weak_crypto": {
|
||||||
|
"patterns": {
|
||||||
|
"python": [
|
||||||
|
(r'hashlib\.md5\s*\(', "MD5哈希"),
|
||||||
|
(r'hashlib\.sha1\s*\(', "SHA1哈希"),
|
||||||
|
(r'DES\s*\(', "DES加密"),
|
||||||
|
(r'random\.random\s*\(', "不安全随机数"),
|
||||||
|
],
|
||||||
|
"javascript": [
|
||||||
|
(r'crypto\.createHash\s*\(\s*["\']md5["\']', "MD5哈希"),
|
||||||
|
(r'crypto\.createHash\s*\(\s*["\']sha1["\']', "SHA1哈希"),
|
||||||
|
(r'Math\.random\s*\(', "Math.random"),
|
||||||
|
],
|
||||||
|
"java": [
|
||||||
|
(r'MessageDigest\.getInstance\s*\(\s*["\']MD5["\']', "MD5哈希"),
|
||||||
|
(r'MessageDigest\.getInstance\s*\(\s*["\']SHA-?1["\']', "SHA1哈希"),
|
||||||
|
(r'DESKeySpec', "DES密钥"),
|
||||||
|
],
|
||||||
|
"php": [
|
||||||
|
(r'md5\s*\(', "MD5哈希"),
|
||||||
|
(r'sha1\s*\(', "SHA1哈希"),
|
||||||
|
(r'mcrypt_', "mcrypt已废弃"),
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"severity": "low",
|
||||||
|
"description": "弱加密算法:使用了不安全的加密或哈希算法",
|
||||||
|
},
|
||||||
|
}
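
# Reading the table (illustrative): each entry is a (regex, label) pair scoped to a
# language. The Python SQL-injection pattern r'cursor\.execute\s*\(\s*f["\']'
# flags a line such as
#     cursor.execute(f"SELECT * FROM users WHERE name = '{name}'")
# because the f-string interpolates data straight into the SQL text, while a
# parameterised call like cursor.execute("... WHERE name = %s", (name,)) is not
# matched by any of the python patterns above.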
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "pattern_match"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
vuln_types = ", ".join(self.PATTERNS.keys())
|
||||||
|
return f"""快速扫描代码中的危险模式和常见漏洞。
|
||||||
|
使用正则表达式检测已知的不安全代码模式。
|
||||||
|
|
||||||
|
支持的漏洞类型: {vuln_types}
|
||||||
|
|
||||||
|
这是一个快速扫描工具,可以在分析开始时使用来快速发现潜在问题。
|
||||||
|
发现的问题需要进一步分析确认。"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def args_schema(self):
|
||||||
|
return PatternMatchInput
|
||||||
|
|
||||||
|
async def _execute(
|
||||||
|
self,
|
||||||
|
code: str,
|
||||||
|
file_path: str = "unknown",
|
||||||
|
pattern_types: Optional[List[str]] = None,
|
||||||
|
language: Optional[str] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> ToolResult:
|
||||||
|
"""执行模式匹配"""
|
||||||
|
matches: List[PatternMatch] = []
|
||||||
|
lines = code.split('\n')
|
||||||
|
|
||||||
|
# 确定要检查的漏洞类型
|
||||||
|
types_to_check = pattern_types or list(self.PATTERNS.keys())
|
||||||
|
|
||||||
|
# 自动检测语言
|
||||||
|
if not language:
|
||||||
|
language = self._detect_language(file_path)
|
||||||
|
|
||||||
|
for vuln_type in types_to_check:
|
||||||
|
if vuln_type not in self.PATTERNS:
|
||||||
|
continue
|
||||||
|
|
||||||
|
pattern_config = self.PATTERNS[vuln_type]
|
||||||
|
patterns_dict = pattern_config["patterns"]
|
||||||
|
|
||||||
|
# 获取语言特定模式和通用模式
|
||||||
|
patterns_to_use = []
|
||||||
|
if language and language in patterns_dict:
|
||||||
|
patterns_to_use.extend(patterns_dict[language])
|
||||||
|
if "_common" in patterns_dict:
|
||||||
|
patterns_to_use.extend(patterns_dict["_common"])
|
||||||
|
|
||||||
|
# 如果没有特定语言模式,尝试使用所有模式
|
||||||
|
if not patterns_to_use:
|
||||||
|
for lang, pats in patterns_dict.items():
|
||||||
|
if lang != "_common":
|
||||||
|
patterns_to_use.extend(pats)
|
||||||
|
|
||||||
|
# 执行匹配
|
||||||
|
for pattern, pattern_name in patterns_to_use:
|
||||||
|
try:
|
||||||
|
for i, line in enumerate(lines):
|
||||||
|
if re.search(pattern, line, re.IGNORECASE):
|
||||||
|
# 获取上下文
|
||||||
|
start = max(0, i - 2)
|
||||||
|
end = min(len(lines), i + 3)
|
||||||
|
context = '\n'.join(f"{j+1}: {lines[j]}" for j in range(start, end))
|
||||||
|
|
||||||
|
matches.append(PatternMatch(
|
||||||
|
pattern_name=pattern_name,
|
||||||
|
pattern_type=vuln_type,
|
||||||
|
file_path=file_path,
|
||||||
|
line_number=i + 1,
|
||||||
|
matched_text=line.strip()[:200],
|
||||||
|
context=context,
|
||||||
|
severity=pattern_config["severity"],
|
||||||
|
description=pattern_config["description"],
|
||||||
|
))
|
||||||
|
except re.error:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not matches:
|
||||||
|
return ToolResult(
|
||||||
|
success=True,
|
||||||
|
data="没有检测到已知的危险模式",
|
||||||
|
metadata={"patterns_checked": len(types_to_check), "matches": 0}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 格式化输出
|
||||||
|
output_parts = [f"⚠️ 检测到 {len(matches)} 个潜在问题:\n"]
|
||||||
|
|
||||||
|
# 按严重程度排序
|
||||||
|
severity_order = {"critical": 0, "high": 1, "medium": 2, "low": 3}
|
||||||
|
matches.sort(key=lambda x: severity_order.get(x.severity, 4))
|
||||||
|
|
||||||
|
for match in matches:
|
||||||
|
severity_icon = {"critical": "🔴", "high": "🟠", "medium": "🟡", "low": "🟢"}.get(match.severity, "⚪")
|
||||||
|
output_parts.append(f"\n{severity_icon} [{match.severity.upper()}] {match.pattern_type}")
|
||||||
|
output_parts.append(f" 位置: {match.file_path}:{match.line_number}")
|
||||||
|
output_parts.append(f" 模式: {match.pattern_name}")
|
||||||
|
output_parts.append(f" 描述: {match.description}")
|
||||||
|
output_parts.append(f" 匹配: {match.matched_text}")
|
||||||
|
output_parts.append(f" 上下文:\n{match.context}")
|
||||||
|
|
||||||
|
return ToolResult(
|
||||||
|
success=True,
|
||||||
|
data="\n".join(output_parts),
|
||||||
|
metadata={
|
||||||
|
"matches": len(matches),
|
||||||
|
"by_severity": {
|
||||||
|
s: len([m for m in matches if m.severity == s])
|
||||||
|
for s in ["critical", "high", "medium", "low"]
|
||||||
|
},
|
||||||
|
"details": [
|
||||||
|
{
|
||||||
|
"type": m.pattern_type,
|
||||||
|
"severity": m.severity,
|
||||||
|
"line": m.line_number,
|
||||||
|
"pattern": m.pattern_name,
|
||||||
|
}
|
||||||
|
for m in matches
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def _detect_language(self, file_path: str) -> Optional[str]:
|
||||||
|
"""根据文件扩展名检测语言"""
|
||||||
|
ext_map = {
|
||||||
|
".py": "python",
|
||||||
|
".js": "javascript",
|
||||||
|
".jsx": "javascript",
|
||||||
|
".ts": "javascript",
|
||||||
|
".tsx": "javascript",
|
||||||
|
".java": "java",
|
||||||
|
".php": "php",
|
||||||
|
".go": "go",
|
||||||
|
".rb": "ruby",
|
||||||
|
}
|
||||||
|
|
||||||
|
for ext, lang in ext_map.items():
|
||||||
|
if file_path.lower().endswith(ext):
|
||||||
|
return lang
|
||||||
|
|
||||||
|
return None
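
# --- Editorial usage sketch (not part of this commit) ----------------------
# A minimal way to exercise the pattern matcher defined above. It assumes
# PatternMatchTool() needs no constructor arguments and that ToolResult
# exposes .success / .data / .metadata as used earlier in this file.
import asyncio


async def _demo_pattern_match() -> None:
    tool = PatternMatchTool()
    sample = 'password = "hunter2"\nh = hashlib.md5(data)'
    result = await tool._execute(code=sample, file_path="demo.py")
    if result.success:
        print(result.data)       # formatted findings, ordered by severity
        print(result.metadata)   # counts per severity plus per-match details

# asyncio.run(_demo_pattern_match())
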
@ -0,0 +1,293 @@
|
||||||
|
"""
|
||||||
|
RAG 检索工具
|
||||||
|
支持语义检索代码
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional, List
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from .base import AgentTool, ToolResult
|
||||||
|
from app.services.rag import CodeRetriever
|
||||||
|
|
||||||
|
|
||||||
|
class RAGQueryInput(BaseModel):
|
||||||
|
"""RAG 查询输入参数"""
|
||||||
|
query: str = Field(description="搜索查询,描述你要找的代码功能或特征")
|
||||||
|
top_k: int = Field(default=10, description="返回结果数量")
|
||||||
|
file_path: Optional[str] = Field(default=None, description="限定搜索的文件路径")
|
||||||
|
language: Optional[str] = Field(default=None, description="限定编程语言")
|
||||||
|
|
||||||
|
|
||||||
|
class RAGQueryTool(AgentTool):
|
||||||
|
"""
|
||||||
|
RAG 代码检索工具
|
||||||
|
使用语义搜索在代码库中查找相关代码
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, retriever: CodeRetriever):
|
||||||
|
super().__init__()
|
||||||
|
self.retriever = retriever
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "rag_query"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return """在代码库中进行语义搜索。
|
||||||
|
使用场景:
|
||||||
|
- 查找特定功能的实现代码
|
||||||
|
- 查找调用某个函数的代码
|
||||||
|
- 查找处理用户输入的代码
|
||||||
|
- 查找数据库操作相关代码
|
||||||
|
- 查找认证/授权相关代码
|
||||||
|
|
||||||
|
输入:
|
||||||
|
- query: 描述你要查找的代码,例如 "处理用户登录的函数"、"SQL查询执行"、"文件上传处理"
|
||||||
|
- top_k: 返回结果数量(默认10)
|
||||||
|
- file_path: 可选,限定在某个文件中搜索
|
||||||
|
- language: 可选,限定编程语言
|
||||||
|
|
||||||
|
输出: 相关的代码片段列表,包含文件路径、行号、代码内容和相似度分数"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def args_schema(self):
|
||||||
|
return RAGQueryInput
|
||||||
|
|
||||||
|
async def _execute(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
top_k: int = 10,
|
||||||
|
file_path: Optional[str] = None,
|
||||||
|
language: Optional[str] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> ToolResult:
|
||||||
|
"""执行 RAG 检索"""
|
||||||
|
try:
|
||||||
|
results = await self.retriever.retrieve(
|
||||||
|
query=query,
|
||||||
|
top_k=top_k,
|
||||||
|
filter_file_path=file_path,
|
||||||
|
filter_language=language,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not results:
|
||||||
|
return ToolResult(
|
||||||
|
success=True,
|
||||||
|
data="没有找到相关代码",
|
||||||
|
metadata={"query": query, "results_count": 0}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 格式化输出
|
||||||
|
output_parts = [f"找到 {len(results)} 个相关代码片段:\n"]
|
||||||
|
|
||||||
|
for i, result in enumerate(results):
|
||||||
|
output_parts.append(f"\n--- 结果 {i+1} (相似度: {result.score:.2f}) ---")
|
||||||
|
output_parts.append(f"文件: {result.file_path}")
|
||||||
|
output_parts.append(f"行号: {result.line_start}-{result.line_end}")
|
||||||
|
if result.name:
|
||||||
|
output_parts.append(f"名称: {result.name}")
|
||||||
|
if result.security_indicators:
|
||||||
|
output_parts.append(f"安全指标: {', '.join(result.security_indicators)}")
|
||||||
|
output_parts.append(f"代码:\n```{result.language}\n{result.content}\n```")
|
||||||
|
|
||||||
|
return ToolResult(
|
||||||
|
success=True,
|
||||||
|
data="\n".join(output_parts),
|
||||||
|
metadata={
|
||||||
|
"query": query,
|
||||||
|
"results_count": len(results),
|
||||||
|
"results": [r.to_dict() for r in results],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return ToolResult(
|
||||||
|
success=False,
|
||||||
|
error=f"RAG 检索失败: {str(e)}",
|
||||||
|
)
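
# --- Editorial usage sketch (not part of this commit) ----------------------
# Invoking the RAG query tool above against an already-built retriever.
# How a CodeRetriever instance is obtained is project-specific and assumed here.
async def _demo_rag_query(retriever: CodeRetriever) -> None:
    tool = RAGQueryTool(retriever)
    result = await tool._execute(
        query="处理用户登录的函数",
        top_k=5,
        language="python",
    )
    print(result.data if result.success else result.error)
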
class SecurityCodeSearchInput(BaseModel):
|
||||||
|
"""安全代码搜索输入"""
|
||||||
|
vulnerability_type: str = Field(
|
||||||
|
description="漏洞类型: sql_injection, xss, command_injection, path_traversal, ssrf, deserialization, auth_bypass, hardcoded_secret"
|
||||||
|
)
|
||||||
|
top_k: int = Field(default=20, description="返回结果数量")
|
||||||
|
|
||||||
|
|
||||||
|
class SecurityCodeSearchTool(AgentTool):
|
||||||
|
"""
|
||||||
|
安全相关代码搜索工具
|
||||||
|
专门用于查找可能存在安全漏洞的代码
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, retriever: CodeRetriever):
|
||||||
|
super().__init__()
|
||||||
|
self.retriever = retriever
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "security_code_search"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return """搜索可能存在安全漏洞的代码。
|
||||||
|
专门针对特定漏洞类型进行搜索。
|
||||||
|
|
||||||
|
支持的漏洞类型:
|
||||||
|
- sql_injection: SQL 注入
|
||||||
|
- xss: 跨站脚本
|
||||||
|
- command_injection: 命令注入
|
||||||
|
- path_traversal: 路径遍历
|
||||||
|
- ssrf: 服务端请求伪造
|
||||||
|
- deserialization: 不安全的反序列化
|
||||||
|
- auth_bypass: 认证绕过
|
||||||
|
- hardcoded_secret: 硬编码密钥"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def args_schema(self):
|
||||||
|
return SecurityCodeSearchInput
|
||||||
|
|
||||||
|
async def _execute(
|
||||||
|
self,
|
||||||
|
vulnerability_type: str,
|
||||||
|
top_k: int = 20,
|
||||||
|
**kwargs
|
||||||
|
) -> ToolResult:
|
||||||
|
"""执行安全代码搜索"""
|
||||||
|
try:
|
||||||
|
results = await self.retriever.retrieve_security_related(
|
||||||
|
vulnerability_type=vulnerability_type,
|
||||||
|
top_k=top_k,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not results:
|
||||||
|
return ToolResult(
|
||||||
|
success=True,
|
||||||
|
data=f"没有找到与 {vulnerability_type} 相关的代码",
|
||||||
|
metadata={"vulnerability_type": vulnerability_type, "results_count": 0}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 格式化输出
|
||||||
|
output_parts = [f"找到 {len(results)} 个可能与 {vulnerability_type} 相关的代码:\n"]
|
||||||
|
|
||||||
|
for i, result in enumerate(results):
|
||||||
|
output_parts.append(f"\n--- 可疑代码 {i+1} ---")
|
||||||
|
output_parts.append(f"文件: {result.file_path}:{result.line_start}")
|
||||||
|
if result.security_indicators:
|
||||||
|
output_parts.append(f"⚠️ 安全指标: {', '.join(result.security_indicators)}")
|
||||||
|
output_parts.append(f"代码:\n```{result.language}\n{result.content}\n```")
|
||||||
|
|
||||||
|
return ToolResult(
|
||||||
|
success=True,
|
||||||
|
data="\n".join(output_parts),
|
||||||
|
metadata={
|
||||||
|
"vulnerability_type": vulnerability_type,
|
||||||
|
"results_count": len(results),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return ToolResult(
|
||||||
|
success=False,
|
||||||
|
error=f"安全代码搜索失败: {str(e)}",
|
||||||
|
)
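
# --- Editorial usage sketch (not part of this commit) ----------------------
# Targeted search for code that may be prone to SQL injection, using the
# security search tool above.
async def _demo_security_search(retriever: CodeRetriever) -> None:
    tool = SecurityCodeSearchTool(retriever)
    result = await tool._execute(vulnerability_type="sql_injection", top_k=10)
    if result.success:
        print(result.metadata["results_count"], "suspicious chunks")
        print(result.data)
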
class FunctionContextInput(BaseModel):
|
||||||
|
"""函数上下文搜索输入"""
|
||||||
|
function_name: str = Field(description="函数名称")
|
||||||
|
file_path: Optional[str] = Field(default=None, description="文件路径")
|
||||||
|
include_callers: bool = Field(default=True, description="是否包含调用者")
|
||||||
|
include_callees: bool = Field(default=True, description="是否包含被调用的函数")
|
||||||
|
|
||||||
|
|
||||||
|
class FunctionContextTool(AgentTool):
|
||||||
|
"""
|
||||||
|
函数上下文搜索工具
|
||||||
|
查找函数的定义、调用者和被调用者
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, retriever: CodeRetriever):
|
||||||
|
super().__init__()
|
||||||
|
self.retriever = retriever
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "function_context"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return """查找函数的上下文信息,包括定义、调用者和被调用的函数。
|
||||||
|
用于追踪数据流和理解函数的使用方式。
|
||||||
|
|
||||||
|
输入:
|
||||||
|
- function_name: 要查找的函数名
|
||||||
|
- file_path: 可选,限定文件路径
|
||||||
|
- include_callers: 是否查找调用此函数的代码
|
||||||
|
- include_callees: 是否查找此函数调用的其他函数"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def args_schema(self):
|
||||||
|
return FunctionContextInput
|
||||||
|
|
||||||
|
async def _execute(
|
||||||
|
self,
|
||||||
|
function_name: str,
|
||||||
|
file_path: Optional[str] = None,
|
||||||
|
include_callers: bool = True,
|
||||||
|
include_callees: bool = True,
|
||||||
|
**kwargs
|
||||||
|
) -> ToolResult:
|
||||||
|
"""执行函数上下文搜索"""
|
||||||
|
try:
|
||||||
|
context = await self.retriever.retrieve_function_context(
|
||||||
|
function_name=function_name,
|
||||||
|
file_path=file_path,
|
||||||
|
include_callers=include_callers,
|
||||||
|
include_callees=include_callees,
|
||||||
|
)
|
||||||
|
|
||||||
|
output_parts = [f"函数 '{function_name}' 的上下文分析:\n"]
|
||||||
|
|
||||||
|
# 函数定义
|
||||||
|
if context["definition"]:
|
||||||
|
output_parts.append("### 函数定义:")
|
||||||
|
for result in context["definition"]:
|
||||||
|
output_parts.append(f"文件: {result.file_path}:{result.line_start}")
|
||||||
|
output_parts.append(f"```{result.language}\n{result.content}\n```")
|
||||||
|
else:
|
||||||
|
output_parts.append("未找到函数定义")
|
||||||
|
|
||||||
|
# 调用者
|
||||||
|
if context["callers"]:
|
||||||
|
output_parts.append(f"\n### 调用此函数的代码 ({len(context['callers'])} 处):")
|
||||||
|
for result in context["callers"][:5]:
|
||||||
|
output_parts.append(f"- {result.file_path}:{result.line_start}")
|
||||||
|
output_parts.append(f"```{result.language}\n{result.content[:500]}\n```")
|
||||||
|
|
||||||
|
# 被调用者
|
||||||
|
if context["callees"]:
|
||||||
|
output_parts.append(f"\n### 此函数调用的其他函数:")
|
||||||
|
for result in context["callees"][:5]:
|
||||||
|
if result.name:
|
||||||
|
output_parts.append(f"- {result.name} ({result.file_path})")
|
||||||
|
|
||||||
|
return ToolResult(
|
||||||
|
success=True,
|
||||||
|
data="\n".join(output_parts),
|
||||||
|
metadata={
|
||||||
|
"function_name": function_name,
|
||||||
|
"definition_count": len(context["definition"]),
|
||||||
|
"callers_count": len(context["callers"]),
|
||||||
|
"callees_count": len(context["callees"]),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return ToolResult(
|
||||||
|
success=False,
|
||||||
|
error=f"函数上下文搜索失败: {str(e)}",
|
||||||
|
)
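
# --- Editorial usage sketch (not part of this commit) ----------------------
# Tracing where a function is defined and called, e.g. to follow user input
# from an entry point down to a sink. The function name is a placeholder.
async def _demo_function_context(retriever: CodeRetriever) -> None:
    tool = FunctionContextTool(retriever)
    result = await tool._execute(function_name="login", include_callers=True)
    print(result.data if result.success else result.error)
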
@ -0,0 +1,647 @@
|
||||||
|
"""
|
||||||
|
沙箱执行工具
|
||||||
|
在 Docker 沙箱中执行代码和命令进行漏洞验证
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
import json
import logging
import re
import tempfile
import os
import shutil
from typing import Optional, List, Dict, Any
from pydantic import BaseModel, Field
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from .base import AgentTool, ToolResult
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SandboxConfig:
|
||||||
|
"""沙箱配置"""
|
||||||
|
image: str = "python:3.11-slim"
|
||||||
|
memory_limit: str = "512m"
|
||||||
|
cpu_limit: float = 1.0
|
||||||
|
timeout: int = 60
|
||||||
|
network_mode: str = "none" # none, bridge, host
|
||||||
|
read_only: bool = True
|
||||||
|
user: str = "1000:1000"
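
# --- Editorial usage sketch (not part of this commit) ----------------------
# Building a tighter sandbox profile than the defaults above; the concrete
# values are assumptions about what a "strict" profile could look like.
STRICT_SANDBOX_CONFIG = SandboxConfig(
    image="python:3.11-slim",
    memory_limit="256m",
    cpu_limit=0.5,
    timeout=30,
    network_mode="none",   # no outbound network at all
    read_only=True,
    user="1000:1000",
)
# manager = SandboxManager(config=STRICT_SANDBOX_CONFIG)
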
class SandboxManager:
|
||||||
|
"""
|
||||||
|
沙箱管理器
|
||||||
|
管理 Docker 容器的创建、执行和清理
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: Optional[SandboxConfig] = None):
|
||||||
|
self.config = config or SandboxConfig()
|
||||||
|
self._docker_client = None
|
||||||
|
self._initialized = False
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
"""初始化 Docker 客户端"""
|
||||||
|
if self._initialized:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
import docker
|
||||||
|
self._docker_client = docker.from_env()
|
||||||
|
# 测试连接
|
||||||
|
self._docker_client.ping()
|
||||||
|
self._initialized = True
|
||||||
|
logger.info("Docker sandbox manager initialized")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Docker not available: {e}")
|
||||||
|
self._docker_client = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_available(self) -> bool:
|
||||||
|
"""检查 Docker 是否可用"""
|
||||||
|
return self._docker_client is not None
|
||||||
|
|
||||||
|
async def execute_command(
|
||||||
|
self,
|
||||||
|
command: str,
|
||||||
|
working_dir: Optional[str] = None,
|
||||||
|
env: Optional[Dict[str, str]] = None,
|
||||||
|
timeout: Optional[int] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
在沙箱中执行命令
|
||||||
|
|
||||||
|
Args:
|
||||||
|
command: 要执行的命令
|
||||||
|
working_dir: 工作目录
|
||||||
|
env: 环境变量
|
||||||
|
timeout: 超时时间(秒)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
执行结果
|
||||||
|
"""
|
||||||
|
if not self.is_available:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": "Docker 不可用",
|
||||||
|
"stdout": "",
|
||||||
|
"stderr": "",
|
||||||
|
"exit_code": -1,
|
||||||
|
}
|
||||||
|
|
||||||
|
timeout = timeout or self.config.timeout
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 创建临时目录
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
# 准备容器配置
|
||||||
|
container_config = {
|
||||||
|
"image": self.config.image,
|
||||||
|
"command": ["sh", "-c", command],
|
||||||
|
"detach": True,
|
||||||
|
"mem_limit": self.config.memory_limit,
|
||||||
|
"cpu_period": 100000,
|
||||||
|
"cpu_quota": int(100000 * self.config.cpu_limit),
|
||||||
|
"network_mode": self.config.network_mode,
|
||||||
|
"user": self.config.user,
|
||||||
|
"read_only": self.config.read_only,
|
||||||
|
"volumes": {
|
||||||
|
temp_dir: {"bind": "/workspace", "mode": "rw"},
|
||||||
|
},
|
||||||
|
"working_dir": working_dir or "/workspace",
|
||||||
|
"environment": env or {},
|
||||||
|
# 安全配置
|
||||||
|
"cap_drop": ["ALL"],
|
||||||
|
"security_opt": ["no-new-privileges:true"],
|
||||||
|
}
|
||||||
|
|
||||||
|
# 创建并启动容器
|
||||||
|
container = await asyncio.to_thread(
|
||||||
|
self._docker_client.containers.run,
|
||||||
|
**container_config
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 等待执行完成
|
||||||
|
result = await asyncio.wait_for(
|
||||||
|
asyncio.to_thread(container.wait),
|
||||||
|
timeout=timeout
|
||||||
|
)
|
||||||
|
|
||||||
|
# 获取日志
|
||||||
|
stdout = await asyncio.to_thread(
|
||||||
|
container.logs, stdout=True, stderr=False
|
||||||
|
)
|
||||||
|
stderr = await asyncio.to_thread(
|
||||||
|
container.logs, stdout=False, stderr=True
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": result["StatusCode"] == 0,
|
||||||
|
"stdout": stdout.decode('utf-8', errors='ignore')[:10000],
|
||||||
|
"stderr": stderr.decode('utf-8', errors='ignore')[:2000],
|
||||||
|
"exit_code": result["StatusCode"],
|
||||||
|
"error": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
await asyncio.to_thread(container.kill)
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": f"执行超时 ({timeout}秒)",
|
||||||
|
"stdout": "",
|
||||||
|
"stderr": "",
|
||||||
|
"exit_code": -1,
|
||||||
|
}
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# 清理容器
|
||||||
|
await asyncio.to_thread(container.remove, force=True)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Sandbox execution error: {e}")
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": str(e),
|
||||||
|
"stdout": "",
|
||||||
|
"stderr": "",
|
||||||
|
"exit_code": -1,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def execute_python(
|
||||||
|
self,
|
||||||
|
code: str,
|
||||||
|
timeout: Optional[int] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
在沙箱中执行 Python 代码
|
||||||
|
|
||||||
|
Args:
|
||||||
|
code: Python 代码
|
||||||
|
timeout: 超时时间
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
执行结果
|
||||||
|
"""
|
||||||
|
# 转义代码中的单引号
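# (POSIX shell: a single-quoted string cannot contain a literal single quote,
#  so each one is rewritten as '\'' -- close the quote, emit an escaped quote,
#  then reopen the quote)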
escaped_code = code.replace("'", "'\\''")
|
||||||
|
command = f"python3 -c '{escaped_code}'"
|
||||||
|
return await self.execute_command(command, timeout=timeout)
|
||||||
|
|
||||||
|
async def execute_http_request(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
url: str,
|
||||||
|
headers: Optional[Dict[str, str]] = None,
|
||||||
|
data: Optional[str] = None,
|
||||||
|
timeout: int = 30,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
在沙箱中执行 HTTP 请求
|
||||||
|
|
||||||
|
Args:
|
||||||
|
method: HTTP 方法
|
||||||
|
url: URL
|
||||||
|
headers: 请求头
|
||||||
|
data: 请求体
|
||||||
|
timeout: 超时
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
HTTP 响应
|
||||||
|
"""
|
||||||
|
# 构建 curl 命令
|
||||||
|
curl_parts = ["curl", "-s", "-S", "-w", "'\\n%{http_code}'", "-X", method]
|
||||||
|
|
||||||
|
if headers:
|
||||||
|
for key, value in headers.items():
|
||||||
|
curl_parts.extend(["-H", f"'{key}: {value}'"])
|
||||||
|
|
||||||
|
if data:
|
||||||
|
curl_parts.extend(["-d", f"'{data}'"])
|
||||||
|
|
||||||
|
curl_parts.append(f"'{url}'")
|
||||||
|
|
||||||
|
command = " ".join(curl_parts)
|
||||||
|
|
||||||
|
# 使用带网络的镜像
|
||||||
|
original_network = self.config.network_mode
|
||||||
|
self.config.network_mode = "bridge" # 允许网络访问
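# NOTE: this temporarily mutates the shared config object; concurrent requests
# through the same SandboxManager could observe the changed network mode.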
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await self.execute_command(command, timeout=timeout)
|
||||||
|
|
||||||
|
if result["success"] and result["stdout"]:
|
||||||
|
lines = result["stdout"].strip().split('\n')
|
||||||
|
if lines:
|
||||||
|
status_code = lines[-1].strip()
|
||||||
|
body = '\n'.join(lines[:-1])
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"status_code": int(status_code) if status_code.isdigit() else 0,
|
||||||
|
"body": body[:5000],
|
||||||
|
"error": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"status_code": 0,
|
||||||
|
"body": "",
|
||||||
|
"error": result.get("error") or result.get("stderr"),
|
||||||
|
}
|
||||||
|
|
||||||
|
finally:
|
||||||
|
self.config.network_mode = original_network
|
||||||
|
|
||||||
|
async def verify_vulnerability(
|
||||||
|
self,
|
||||||
|
vulnerability_type: str,
|
||||||
|
target_url: str,
|
||||||
|
payload: str,
|
||||||
|
expected_pattern: Optional[str] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
验证漏洞
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vulnerability_type: 漏洞类型
|
||||||
|
target_url: 目标 URL
|
||||||
|
payload: 攻击载荷
|
||||||
|
expected_pattern: 期望在响应中匹配的模式
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
验证结果
|
||||||
|
"""
|
||||||
|
verification_result = {
|
||||||
|
"vulnerability_type": vulnerability_type,
|
||||||
|
"target_url": target_url,
|
||||||
|
"payload": payload,
|
||||||
|
"is_vulnerable": False,
|
||||||
|
"evidence": None,
|
||||||
|
"error": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 发送请求
|
||||||
|
response = await self.execute_http_request(
|
||||||
|
method="GET" if "?" in target_url else "POST",
|
||||||
|
url=target_url,
|
||||||
|
data=payload if "?" not in target_url else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not response["success"]:
|
||||||
|
verification_result["error"] = response.get("error")
|
||||||
|
return verification_result
|
||||||
|
|
||||||
|
body = response.get("body", "")
|
||||||
|
status_code = response.get("status_code", 0)
|
||||||
|
|
||||||
|
# 检查响应
|
||||||
|
if expected_pattern:
if re.search(expected_pattern, body, re.IGNORECASE):
|
||||||
|
verification_result["is_vulnerable"] = True
|
||||||
|
verification_result["evidence"] = f"响应中包含预期模式: {expected_pattern}"
|
||||||
|
else:
|
||||||
|
# 根据漏洞类型进行通用检查
|
||||||
|
if vulnerability_type == "sql_injection":
|
||||||
|
error_patterns = [
|
||||||
|
r"SQL syntax",
|
||||||
|
r"mysql_fetch",
|
||||||
|
r"ORA-\d+",
|
||||||
|
r"PostgreSQL.*ERROR",
|
||||||
|
r"SQLite.*error",
|
||||||
|
r"ODBC.*Driver",
|
||||||
|
]
|
||||||
|
for pattern in error_patterns:
|
||||||
|
if re.search(pattern, body, re.IGNORECASE):
|
||||||
|
verification_result["is_vulnerable"] = True
|
||||||
|
verification_result["evidence"] = f"SQL错误信息: {pattern}"
|
||||||
|
break
|
||||||
|
|
||||||
|
elif vulnerability_type == "xss":
|
||||||
|
if payload in body:
|
||||||
|
verification_result["is_vulnerable"] = True
|
||||||
|
verification_result["evidence"] = "XSS payload 被反射到响应中"
|
||||||
|
|
||||||
|
elif vulnerability_type == "command_injection":
|
||||||
|
# 检查命令执行结果
|
||||||
|
if "uid=" in body or "root:" in body:
|
||||||
|
verification_result["is_vulnerable"] = True
|
||||||
|
verification_result["evidence"] = "命令执行成功"
|
||||||
|
|
||||||
|
verification_result["response_status"] = status_code
|
||||||
|
verification_result["response_length"] = len(body)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
verification_result["error"] = str(e)
|
||||||
|
|
||||||
|
return verification_result
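
# --- Editorial usage sketch (not part of this commit) ----------------------
# Driving the verifier above against a suspected SQL-injection endpoint.
# URL and payload are placeholders; Docker must be available for a real run.
async def _demo_verify_sqli() -> None:
    manager = SandboxManager()
    await manager.initialize()
    report = await manager.verify_vulnerability(
        vulnerability_type="sql_injection",
        target_url="http://target.local/search",
        payload="q=1' OR '1'='1",
        expected_pattern=r"SQL syntax",
    )
    print(report["is_vulnerable"], report.get("evidence"))
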
class SandboxCommandInput(BaseModel):
|
||||||
|
"""沙箱命令输入"""
|
||||||
|
command: str = Field(description="要执行的命令")
|
||||||
|
timeout: int = Field(default=30, description="超时时间(秒)")
|
||||||
|
|
||||||
|
|
||||||
|
class SandboxTool(AgentTool):
|
||||||
|
"""
|
||||||
|
沙箱执行工具
|
||||||
|
在安全隔离的环境中执行代码和命令
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 允许的命令前缀
|
||||||
|
ALLOWED_COMMANDS = [
|
||||||
|
"python", "python3", "node", "curl", "wget",
|
||||||
|
"cat", "head", "tail", "grep", "find", "ls",
|
||||||
|
"echo", "printf", "test", "id", "whoami",
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self, sandbox_manager: Optional[SandboxManager] = None):
|
||||||
|
super().__init__()
|
||||||
|
self.sandbox_manager = sandbox_manager or SandboxManager()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "sandbox_exec"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return """在安全沙箱中执行命令或代码。
|
||||||
|
用于验证漏洞、测试 PoC 或执行安全检查。
|
||||||
|
|
||||||
|
⚠️ 安全限制:
|
||||||
|
- 命令在 Docker 容器中执行
|
||||||
|
- 网络默认隔离
|
||||||
|
- 资源有限制
|
||||||
|
- 只允许特定命令
|
||||||
|
|
||||||
|
允许的命令: python, python3, node, curl, cat, grep, find, ls, echo, id
|
||||||
|
|
||||||
|
使用场景:
|
||||||
|
- 验证命令注入漏洞
|
||||||
|
- 执行 PoC 代码
|
||||||
|
- 测试 payload 效果"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def args_schema(self):
|
||||||
|
return SandboxCommandInput
|
||||||
|
|
||||||
|
async def _execute(
|
||||||
|
self,
|
||||||
|
command: str,
|
||||||
|
timeout: int = 30,
|
||||||
|
**kwargs
|
||||||
|
) -> ToolResult:
|
||||||
|
"""执行沙箱命令"""
|
||||||
|
# 初始化沙箱
|
||||||
|
await self.sandbox_manager.initialize()
|
||||||
|
|
||||||
|
if not self.sandbox_manager.is_available:
|
||||||
|
return ToolResult(
|
||||||
|
success=False,
|
||||||
|
error="沙箱环境不可用(Docker 未安装或未运行)",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 安全检查:验证命令是否允许
|
||||||
|
cmd_parts = command.strip().split()
|
||||||
|
if not cmd_parts:
|
||||||
|
return ToolResult(success=False, error="命令不能为空")
|
||||||
|
|
||||||
|
base_cmd = cmd_parts[0]
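# NOTE: the allow-list below is matched by prefix (startswith), so e.g.
# "python3.11" is also accepted; an exact-match check would be stricter.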
if not any(base_cmd.startswith(allowed) for allowed in self.ALLOWED_COMMANDS):
|
||||||
|
return ToolResult(
|
||||||
|
success=False,
|
||||||
|
error=f"命令 '{base_cmd}' 不在允许列表中。允许的命令: {', '.join(self.ALLOWED_COMMANDS)}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 执行命令
|
||||||
|
result = await self.sandbox_manager.execute_command(
|
||||||
|
command=command,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 格式化输出
|
||||||
|
output_parts = ["🐳 沙箱执行结果\n"]
|
||||||
|
output_parts.append(f"命令: {command}")
|
||||||
|
output_parts.append(f"退出码: {result['exit_code']}")
|
||||||
|
|
||||||
|
if result["stdout"]:
|
||||||
|
output_parts.append(f"\n标准输出:\n```\n{result['stdout']}\n```")
|
||||||
|
|
||||||
|
if result["stderr"]:
|
||||||
|
output_parts.append(f"\n标准错误:\n```\n{result['stderr']}\n```")
|
||||||
|
|
||||||
|
if result.get("error"):
|
||||||
|
output_parts.append(f"\n错误: {result['error']}")
|
||||||
|
|
||||||
|
return ToolResult(
|
||||||
|
success=result["success"],
|
||||||
|
data="\n".join(output_parts),
|
||||||
|
error=result.get("error"),
|
||||||
|
metadata={
|
||||||
|
"command": command,
|
||||||
|
"exit_code": result["exit_code"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class HttpRequestInput(BaseModel):
|
||||||
|
"""HTTP 请求输入"""
|
||||||
|
method: str = Field(default="GET", description="HTTP 方法 (GET, POST, PUT, DELETE)")
|
||||||
|
url: str = Field(description="请求 URL")
|
||||||
|
headers: Optional[Dict[str, str]] = Field(default=None, description="请求头")
|
||||||
|
data: Optional[str] = Field(default=None, description="请求体")
|
||||||
|
timeout: int = Field(default=30, description="超时时间(秒)")
|
||||||
|
|
||||||
|
|
||||||
|
class SandboxHttpTool(AgentTool):
|
||||||
|
"""
|
||||||
|
沙箱 HTTP 请求工具
|
||||||
|
在沙箱中发送 HTTP 请求
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, sandbox_manager: Optional[SandboxManager] = None):
|
||||||
|
super().__init__()
|
||||||
|
self.sandbox_manager = sandbox_manager or SandboxManager()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "sandbox_http"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return """在沙箱中发送 HTTP 请求。
|
||||||
|
用于测试 Web 漏洞如 SQL 注入、XSS、SSRF 等。
|
||||||
|
|
||||||
|
输入:
|
||||||
|
- method: HTTP 方法
|
||||||
|
- url: 请求 URL
|
||||||
|
- headers: 可选,请求头
|
||||||
|
- data: 可选,请求体
|
||||||
|
- timeout: 超时时间
|
||||||
|
|
||||||
|
使用场景:
|
||||||
|
- 验证 SQL 注入漏洞
|
||||||
|
- 测试 XSS payload
|
||||||
|
- 验证 SSRF 漏洞
|
||||||
|
- 测试认证绕过"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def args_schema(self):
|
||||||
|
return HttpRequestInput
|
||||||
|
|
||||||
|
async def _execute(
|
||||||
|
self,
|
||||||
|
url: str,
|
||||||
|
method: str = "GET",
|
||||||
|
headers: Optional[Dict[str, str]] = None,
|
||||||
|
data: Optional[str] = None,
|
||||||
|
timeout: int = 30,
|
||||||
|
**kwargs
|
||||||
|
) -> ToolResult:
|
||||||
|
"""执行 HTTP 请求"""
|
||||||
|
await self.sandbox_manager.initialize()
|
||||||
|
|
||||||
|
if not self.sandbox_manager.is_available:
|
||||||
|
return ToolResult(
|
||||||
|
success=False,
|
||||||
|
error="沙箱环境不可用",
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await self.sandbox_manager.execute_http_request(
|
||||||
|
method=method,
|
||||||
|
url=url,
|
||||||
|
headers=headers,
|
||||||
|
data=data,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
output_parts = ["🌐 HTTP 请求结果\n"]
|
||||||
|
output_parts.append(f"请求: {method} {url}")
|
||||||
|
|
||||||
|
if headers:
|
||||||
|
output_parts.append(f"请求头: {json.dumps(headers, ensure_ascii=False)}")
|
||||||
|
|
||||||
|
if data:
|
||||||
|
output_parts.append(f"请求体: {data[:500]}")
|
||||||
|
|
||||||
|
output_parts.append(f"\n状态码: {result.get('status_code', 'N/A')}")
|
||||||
|
|
||||||
|
if result.get("body"):
|
||||||
|
body = result["body"]
|
||||||
|
if len(body) > 2000:
|
||||||
|
body = body[:2000] + f"\n... (截断,共 {len(result['body'])} 字符)"
|
||||||
|
output_parts.append(f"\n响应内容:\n```\n{body}\n```")
|
||||||
|
|
||||||
|
if result.get("error"):
|
||||||
|
output_parts.append(f"\n错误: {result['error']}")
|
||||||
|
|
||||||
|
return ToolResult(
|
||||||
|
success=result["success"],
|
||||||
|
data="\n".join(output_parts),
|
||||||
|
error=result.get("error"),
|
||||||
|
metadata={
|
||||||
|
"method": method,
|
||||||
|
"url": url,
|
||||||
|
"status_code": result.get("status_code"),
|
||||||
|
"response_length": len(result.get("body", "")),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class VulnerabilityVerifyInput(BaseModel):
|
||||||
|
"""漏洞验证输入"""
|
||||||
|
vulnerability_type: str = Field(description="漏洞类型 (sql_injection, xss, command_injection, etc.)")
|
||||||
|
target_url: str = Field(description="目标 URL")
|
||||||
|
payload: str = Field(description="攻击载荷")
|
||||||
|
expected_pattern: Optional[str] = Field(default=None, description="期望在响应中匹配的正则模式")
|
||||||
|
|
||||||
|
|
||||||
|
class VulnerabilityVerifyTool(AgentTool):
|
||||||
|
"""
|
||||||
|
漏洞验证工具
|
||||||
|
在沙箱中验证漏洞是否真实存在
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, sandbox_manager: Optional[SandboxManager] = None):
|
||||||
|
super().__init__()
|
||||||
|
self.sandbox_manager = sandbox_manager or SandboxManager()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "verify_vulnerability"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return """验证漏洞是否真实存在。
|
||||||
|
发送包含攻击载荷的请求,分析响应判断漏洞是否可利用。
|
||||||
|
|
||||||
|
输入:
|
||||||
|
- vulnerability_type: 漏洞类型
|
||||||
|
- target_url: 目标 URL
|
||||||
|
- payload: 攻击载荷
|
||||||
|
- expected_pattern: 可选,期望在响应中匹配的模式
|
||||||
|
|
||||||
|
支持的漏洞类型:
|
||||||
|
- sql_injection: SQL 注入
|
||||||
|
- xss: 跨站脚本
|
||||||
|
- command_injection: 命令注入
|
||||||
|
- path_traversal: 路径遍历
|
||||||
|
- ssrf: 服务端请求伪造"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def args_schema(self):
|
||||||
|
return VulnerabilityVerifyInput
|
||||||
|
|
||||||
|
async def _execute(
|
||||||
|
self,
|
||||||
|
vulnerability_type: str,
|
||||||
|
target_url: str,
|
||||||
|
payload: str,
|
||||||
|
expected_pattern: Optional[str] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> ToolResult:
|
||||||
|
"""执行漏洞验证"""
|
||||||
|
await self.sandbox_manager.initialize()
|
||||||
|
|
||||||
|
if not self.sandbox_manager.is_available:
|
||||||
|
return ToolResult(
|
||||||
|
success=False,
|
||||||
|
error="沙箱环境不可用",
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await self.sandbox_manager.verify_vulnerability(
|
||||||
|
vulnerability_type=vulnerability_type,
|
||||||
|
target_url=target_url,
|
||||||
|
payload=payload,
|
||||||
|
expected_pattern=expected_pattern,
|
||||||
|
)
|
||||||
|
|
||||||
|
output_parts = ["🔍 漏洞验证结果\n"]
|
||||||
|
output_parts.append(f"漏洞类型: {vulnerability_type}")
|
||||||
|
output_parts.append(f"目标: {target_url}")
|
||||||
|
output_parts.append(f"Payload: {payload[:200]}")
|
||||||
|
|
||||||
|
if result["is_vulnerable"]:
|
||||||
|
output_parts.append(f"\n🔴 结果: 漏洞已确认!")
|
||||||
|
output_parts.append(f"证据: {result.get('evidence', 'N/A')}")
|
||||||
|
else:
|
||||||
|
output_parts.append(f"\n🟢 结果: 未能确认漏洞")
|
||||||
|
if result.get("error"):
|
||||||
|
output_parts.append(f"错误: {result['error']}")
|
||||||
|
|
||||||
|
if result.get("response_status"):
|
||||||
|
output_parts.append(f"\nHTTP 状态码: {result['response_status']}")
|
||||||
|
|
||||||
|
return ToolResult(
|
||||||
|
success=True,
|
||||||
|
data="\n".join(output_parts),
|
||||||
|
metadata={
|
||||||
|
"vulnerability_type": vulnerability_type,
|
||||||
|
"is_vulnerable": result["is_vulnerable"],
|
||||||
|
"evidence": result.get("evidence"),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
@ -144,3 +144,4 @@ class BaiduAdapter(BaseLLMAdapter):
return True

@ -82,3 +82,4 @@ class DoubaoAdapter(BaseLLMAdapter):
return True

@ -85,3 +85,4 @@ class MinimaxAdapter(BaseLLMAdapter):
return True

@ -133,3 +133,4 @@ class BaseLLMAdapter(ABC):
self._client = None

@ -119,3 +119,4 @@ DEFAULT_BASE_URLS: Dict[LLMProvider, str] = {
}

@ -0,0 +1,18 @@
"""
RAG (Retrieval-Augmented Generation) 系统
用于代码索引和语义检索
"""

from .splitter import CodeSplitter, CodeChunk
from .embeddings import EmbeddingService
from .indexer import CodeIndexer
from .retriever import CodeRetriever

__all__ = [
    "CodeSplitter",
    "CodeChunk",
    "EmbeddingService",
    "CodeIndexer",
    "CodeRetriever",
]

@ -0,0 +1,658 @@
|
||||||
|
"""
|
||||||
|
嵌入模型服务
|
||||||
|
支持多种嵌入模型提供商: OpenAI, Azure, Ollama, Cohere, HuggingFace, Jina
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
from typing import List, Dict, Any, Optional
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EmbeddingResult:
|
||||||
|
"""嵌入结果"""
|
||||||
|
embedding: List[float]
|
||||||
|
tokens_used: int
|
||||||
|
model: str
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingProvider(ABC):
|
||||||
|
"""嵌入提供商基类"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def embed_text(self, text: str) -> EmbeddingResult:
|
||||||
|
"""嵌入单个文本"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
|
||||||
|
"""批量嵌入文本"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def dimension(self) -> int:
|
||||||
|
"""嵌入向量维度"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIEmbedding(EmbeddingProvider):
|
||||||
|
"""OpenAI 嵌入服务"""
|
||||||
|
|
||||||
|
MODELS = {
|
||||||
|
"text-embedding-3-small": 1536,
|
||||||
|
"text-embedding-3-large": 3072,
|
||||||
|
"text-embedding-ada-002": 1536,
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
base_url: Optional[str] = None,
|
||||||
|
model: str = "text-embedding-3-small",
|
||||||
|
):
|
||||||
|
self.api_key = api_key or settings.LLM_API_KEY
|
||||||
|
self.base_url = base_url or "https://api.openai.com/v1"
|
||||||
|
self.model = model
|
||||||
|
self._dimension = self.MODELS.get(model, 1536)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dimension(self) -> int:
|
||||||
|
return self._dimension
|
||||||
|
|
||||||
|
async def embed_text(self, text: str) -> EmbeddingResult:
|
||||||
|
results = await self.embed_texts([text])
|
||||||
|
return results[0]
|
||||||
|
|
||||||
|
async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
|
||||||
|
if not texts:
|
||||||
|
return []
|
||||||
|
|
||||||
|
max_length = 8191
|
||||||
|
truncated_texts = [text[:max_length] for text in texts]
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"model": self.model,
|
||||||
|
"input": truncated_texts,
|
||||||
|
}
|
||||||
|
|
||||||
|
url = f"{self.base_url.rstrip('/')}/embeddings"
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=60) as client:
|
||||||
|
response = await client.post(url, headers=headers, json=payload)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for item in data.get("data", []):
|
||||||
|
results.append(EmbeddingResult(
|
||||||
|
embedding=item["embedding"],
|
||||||
|
tokens_used=data.get("usage", {}).get("total_tokens", 0) // len(texts),
|
||||||
|
model=self.model,
|
||||||
|
))
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
class AzureOpenAIEmbedding(EmbeddingProvider):
|
||||||
|
"""
|
||||||
|
Azure OpenAI 嵌入服务
|
||||||
|
|
||||||
|
使用最新 API 版本 2024-10-21 (GA)
|
||||||
|
端点格式: https://<resource>.openai.azure.com/openai/deployments/<deployment>/embeddings?api-version=2024-10-21
|
||||||
|
"""
|
||||||
|
|
||||||
|
MODELS = {
|
||||||
|
"text-embedding-3-small": 1536,
|
||||||
|
"text-embedding-3-large": 3072,
|
||||||
|
"text-embedding-ada-002": 1536,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 最新的 GA API 版本
|
||||||
|
API_VERSION = "2024-10-21"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
base_url: Optional[str] = None,
|
||||||
|
model: str = "text-embedding-3-small",
|
||||||
|
):
|
||||||
|
self.api_key = api_key
|
||||||
|
self.base_url = base_url or "https://your-resource.openai.azure.com"
|
||||||
|
self.model = model
|
||||||
|
self._dimension = self.MODELS.get(model, 1536)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dimension(self) -> int:
|
||||||
|
return self._dimension
|
||||||
|
|
||||||
|
async def embed_text(self, text: str) -> EmbeddingResult:
|
||||||
|
results = await self.embed_texts([text])
|
||||||
|
return results[0]
|
||||||
|
|
||||||
|
async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
|
||||||
|
if not texts:
|
||||||
|
return []
|
||||||
|
|
||||||
|
max_length = 8191
|
||||||
|
truncated_texts = [text[:max_length] for text in texts]
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"api-key": self.api_key,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"input": truncated_texts,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Azure URL 格式 - 使用最新 API 版本
|
||||||
|
url = f"{self.base_url.rstrip('/')}/openai/deployments/{self.model}/embeddings?api-version={self.API_VERSION}"
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=60) as client:
|
||||||
|
response = await client.post(url, headers=headers, json=payload)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for item in data.get("data", []):
|
||||||
|
results.append(EmbeddingResult(
|
||||||
|
embedding=item["embedding"],
|
||||||
|
tokens_used=data.get("usage", {}).get("total_tokens", 0) // len(texts),
|
||||||
|
model=self.model,
|
||||||
|
))
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaEmbedding(EmbeddingProvider):
|
||||||
|
"""
|
||||||
|
Ollama 本地嵌入服务
|
||||||
|
|
||||||
|
使用新的 /api/embed 端点 (2024年起):
|
||||||
|
- 支持批量嵌入
|
||||||
|
- 使用 'input' 参数(支持字符串或字符串数组)
|
||||||
|
"""
|
||||||
|
|
||||||
|
MODELS = {
|
||||||
|
"nomic-embed-text": 768,
|
||||||
|
"mxbai-embed-large": 1024,
|
||||||
|
"all-minilm": 384,
|
||||||
|
"snowflake-arctic-embed": 1024,
|
||||||
|
"bge-m3": 1024,
|
||||||
|
"qwen3-embedding": 1024,
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
base_url: Optional[str] = None,
|
||||||
|
model: str = "nomic-embed-text",
|
||||||
|
):
|
||||||
|
self.base_url = base_url or "http://localhost:11434"
|
||||||
|
self.model = model
|
||||||
|
self._dimension = self.MODELS.get(model, 768)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dimension(self) -> int:
|
||||||
|
return self._dimension
|
||||||
|
|
||||||
|
async def embed_text(self, text: str) -> EmbeddingResult:
|
||||||
|
results = await self.embed_texts([text])
|
||||||
|
return results[0]
|
||||||
|
|
||||||
|
async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
|
||||||
|
if not texts:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 新的 Ollama /api/embed 端点
|
||||||
|
url = f"{self.base_url.rstrip('/')}/api/embed"
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"model": self.model,
|
||||||
|
"input": texts, # 新 API 使用 'input' 参数,支持批量
|
||||||
|
}
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=120) as client:
|
||||||
|
response = await client.post(url, json=payload)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
# 新 API 返回格式: {"embeddings": [[...], [...], ...]}
|
||||||
|
embeddings = data.get("embeddings", [])
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for i, embedding in enumerate(embeddings):
|
||||||
|
results.append(EmbeddingResult(
|
||||||
|
embedding=embedding,
|
||||||
|
tokens_used=len(texts[i]) // 4,
|
||||||
|
model=self.model,
|
||||||
|
))
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
class CohereEmbedding(EmbeddingProvider):
|
||||||
|
"""
|
||||||
|
Cohere 嵌入服务
|
||||||
|
|
||||||
|
使用新的 v2 API (2024年起):
|
||||||
|
- 端点: https://api.cohere.com/v2/embed
|
||||||
|
- 使用 'inputs' 参数替代 'texts'
|
||||||
|
- 需要指定 'embedding_types'
|
||||||
|
"""
|
||||||
|
|
||||||
|
MODELS = {
|
||||||
|
"embed-english-v3.0": 1024,
|
||||||
|
"embed-multilingual-v3.0": 1024,
|
||||||
|
"embed-english-light-v3.0": 384,
|
||||||
|
"embed-multilingual-light-v3.0": 384,
|
||||||
|
"embed-v4.0": 1024, # 最新模型
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
base_url: Optional[str] = None,
|
||||||
|
model: str = "embed-multilingual-v3.0",
|
||||||
|
):
|
||||||
|
self.api_key = api_key
|
||||||
|
# 新的 v2 API 端点
|
||||||
|
self.base_url = base_url or "https://api.cohere.com/v2"
|
||||||
|
self.model = model
|
||||||
|
self._dimension = self.MODELS.get(model, 1024)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dimension(self) -> int:
|
||||||
|
return self._dimension
|
||||||
|
|
||||||
|
async def embed_text(self, text: str) -> EmbeddingResult:
|
||||||
|
results = await self.embed_texts([text])
|
||||||
|
return results[0]
|
||||||
|
|
||||||
|
async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
|
||||||
|
if not texts:
|
||||||
|
return []
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
# v2 API 参数格式
|
||||||
|
payload = {
|
||||||
|
"model": self.model,
|
||||||
|
"inputs": texts, # v2 使用 'inputs' 而非 'texts'
|
||||||
|
"input_type": "search_document",
|
||||||
|
"embedding_types": ["float"], # v2 需要指定嵌入类型
|
||||||
|
}
|
||||||
|
|
||||||
|
url = f"{self.base_url.rstrip('/')}/embed"
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=60) as client:
|
||||||
|
response = await client.post(url, headers=headers, json=payload)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
results = []
|
||||||
|
# v2 API 返回格式: {"embeddings": {"float": [[...], [...]]}, ...}
|
||||||
|
embeddings_data = data.get("embeddings", {})
|
||||||
|
embeddings = embeddings_data.get("float", []) if isinstance(embeddings_data, dict) else embeddings_data
|
||||||
|
|
||||||
|
for embedding in embeddings:
|
||||||
|
results.append(EmbeddingResult(
|
||||||
|
embedding=embedding,
|
||||||
|
tokens_used=data.get("meta", {}).get("billed_units", {}).get("input_tokens", 0) // max(len(texts), 1),
|
||||||
|
model=self.model,
|
||||||
|
))
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
class HuggingFaceEmbedding(EmbeddingProvider):
|
||||||
|
"""
|
||||||
|
HuggingFace Inference Providers 嵌入服务
|
||||||
|
|
||||||
|
使用新的 Router 端点 (2025年起):
|
||||||
|
https://router.huggingface.co/hf-inference/models/{model}/pipeline/feature-extraction
|
||||||
|
"""
|
||||||
|
|
||||||
|
MODELS = {
|
||||||
|
"sentence-transformers/all-MiniLM-L6-v2": 384,
|
||||||
|
"sentence-transformers/all-mpnet-base-v2": 768,
|
||||||
|
"BAAI/bge-large-zh-v1.5": 1024,
|
||||||
|
"BAAI/bge-m3": 1024,
|
||||||
|
"BAAI/bge-small-en-v1.5": 384,
|
||||||
|
"BAAI/bge-base-en-v1.5": 768,
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
base_url: Optional[str] = None,
|
||||||
|
model: str = "BAAI/bge-m3",
|
||||||
|
):
|
||||||
|
self.api_key = api_key
|
||||||
|
# 新的 Router 端点
|
||||||
|
self.base_url = base_url or "https://router.huggingface.co"
|
||||||
|
self.model = model
|
||||||
|
self._dimension = self.MODELS.get(model, 1024)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dimension(self) -> int:
|
||||||
|
return self._dimension
|
||||||
|
|
||||||
|
async def embed_text(self, text: str) -> EmbeddingResult:
|
||||||
|
results = await self.embed_texts([text])
|
||||||
|
return results[0]
|
||||||
|
|
||||||
|
async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
|
||||||
|
if not texts:
|
||||||
|
return []
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
# 新的 HuggingFace Router URL 格式
|
||||||
|
# https://router.huggingface.co/hf-inference/models/{model}/pipeline/feature-extraction
|
||||||
|
url = f"{self.base_url.rstrip('/')}/hf-inference/models/{self.model}/pipeline/feature-extraction"
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"inputs": texts,
|
||||||
|
"options": {
|
||||||
|
"wait_for_model": True,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=120) as client:
|
||||||
|
response = await client.post(url, headers=headers, json=payload)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
results = []
|
||||||
|
# HuggingFace 返回格式: [[embedding1], [embedding2], ...]
|
||||||
|
for embedding in data:
|
||||||
|
# 有时候返回的是嵌套的列表
|
||||||
|
if isinstance(embedding, list) and len(embedding) > 0:
|
||||||
|
if isinstance(embedding[0], list):
|
||||||
|
# 取平均或第一个
|
||||||
|
embedding = embedding[0]
|
||||||
|
|
||||||
|
results.append(EmbeddingResult(
|
||||||
|
embedding=embedding,
|
||||||
|
tokens_used=len(texts[len(results)]) // 4,
|
||||||
|
model=self.model,
|
||||||
|
))
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
class JinaEmbedding(EmbeddingProvider):
|
||||||
|
"""Jina AI 嵌入服务"""
|
||||||
|
|
||||||
|
MODELS = {
|
||||||
|
"jina-embeddings-v2-base-code": 768,
|
||||||
|
"jina-embeddings-v2-base-en": 768,
|
||||||
|
"jina-embeddings-v2-base-zh": 768,
|
||||||
|
"jina-embeddings-v2-small-en": 512,
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
base_url: Optional[str] = None,
|
||||||
|
model: str = "jina-embeddings-v2-base-code",
|
||||||
|
):
|
||||||
|
self.api_key = api_key
|
||||||
|
self.base_url = base_url or "https://api.jina.ai/v1"
|
||||||
|
self.model = model
|
||||||
|
self._dimension = self.MODELS.get(model, 768)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dimension(self) -> int:
|
||||||
|
return self._dimension
|
||||||
|
|
||||||
|
async def embed_text(self, text: str) -> EmbeddingResult:
|
||||||
|
results = await self.embed_texts([text])
|
||||||
|
return results[0]
|
||||||
|
|
||||||
|
async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
|
||||||
|
if not texts:
|
||||||
|
return []
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"model": self.model,
|
||||||
|
"input": texts,
|
||||||
|
}
|
||||||
|
|
||||||
|
url = f"{self.base_url.rstrip('/')}/embeddings"
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=60) as client:
|
||||||
|
response = await client.post(url, headers=headers, json=payload)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for item in data.get("data", []):
|
||||||
|
results.append(EmbeddingResult(
|
||||||
|
embedding=item["embedding"],
|
||||||
|
tokens_used=data.get("usage", {}).get("total_tokens", 0) // len(texts),
|
||||||
|
model=self.model,
|
||||||
|
))
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingService:
|
||||||
|
"""
|
||||||
|
嵌入服务
|
||||||
|
统一管理嵌入模型和缓存
|
||||||
|
|
||||||
|
支持的提供商:
|
||||||
|
- openai: OpenAI 官方
|
||||||
|
- azure: Azure OpenAI
|
||||||
|
- ollama: Ollama 本地
|
||||||
|
- cohere: Cohere
|
||||||
|
- huggingface: HuggingFace Inference API
|
||||||
|
- jina: Jina AI
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
provider: Optional[str] = None,
|
||||||
|
model: Optional[str] = None,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
base_url: Optional[str] = None,
|
||||||
|
cache_enabled: bool = True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
初始化嵌入服务
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider: 提供商 (openai, azure, ollama, cohere, huggingface, jina)
|
||||||
|
model: 模型名称
|
||||||
|
api_key: API Key
|
||||||
|
base_url: API Base URL
|
||||||
|
cache_enabled: 是否启用缓存
|
||||||
|
"""
|
||||||
|
self.cache_enabled = cache_enabled
|
||||||
|
self._cache: Dict[str, List[float]] = {}
|
||||||
|
|
||||||
|
# 确定提供商
|
||||||
|
provider = provider or getattr(settings, 'EMBEDDING_PROVIDER', 'openai')
|
||||||
|
model = model or getattr(settings, 'EMBEDDING_MODEL', 'text-embedding-3-small')
|
||||||
|
|
||||||
|
# 创建提供商实例
|
||||||
|
self._provider = self._create_provider(
|
||||||
|
provider=provider,
|
||||||
|
model=model,
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=base_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Embedding service initialized with {provider}/{model}")
|
||||||
|
|
||||||
|
def _create_provider(
|
||||||
|
self,
|
||||||
|
provider: str,
|
||||||
|
model: str,
|
||||||
|
api_key: Optional[str],
|
||||||
|
base_url: Optional[str],
|
||||||
|
) -> EmbeddingProvider:
|
||||||
|
"""创建嵌入提供商实例"""
|
||||||
|
provider = provider.lower()
|
||||||
|
|
||||||
|
if provider == "ollama":
|
||||||
|
return OllamaEmbedding(base_url=base_url, model=model)
|
||||||
|
|
||||||
|
elif provider == "azure":
|
||||||
|
return AzureOpenAIEmbedding(api_key=api_key, base_url=base_url, model=model)
|
||||||
|
|
||||||
|
elif provider == "cohere":
|
||||||
|
return CohereEmbedding(api_key=api_key, base_url=base_url, model=model)
|
||||||
|
|
||||||
|
elif provider == "huggingface":
|
||||||
|
return HuggingFaceEmbedding(api_key=api_key, base_url=base_url, model=model)
|
||||||
|
|
||||||
|
elif provider == "jina":
|
||||||
|
return JinaEmbedding(api_key=api_key, base_url=base_url, model=model)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# 默认使用 OpenAI
|
||||||
|
return OpenAIEmbedding(api_key=api_key, base_url=base_url, model=model)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dimension(self) -> int:
|
||||||
|
"""嵌入向量维度"""
|
||||||
|
return self._provider.dimension
|
||||||
|
|
||||||
|
def _cache_key(self, text: str) -> str:
|
||||||
|
"""生成缓存键"""
|
||||||
|
return hashlib.sha256(text.encode()).hexdigest()[:32]
|
||||||
|
|
||||||
|
async def embed(self, text: str) -> List[float]:
|
||||||
|
"""
|
||||||
|
嵌入单个文本
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 文本内容
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
嵌入向量
|
||||||
|
"""
|
||||||
|
if not text or not text.strip():
|
||||||
|
return [0.0] * self.dimension
|
||||||
|
|
||||||
|
# 检查缓存
|
||||||
|
if self.cache_enabled:
|
||||||
|
cache_key = self._cache_key(text)
|
||||||
|
if cache_key in self._cache:
|
||||||
|
return self._cache[cache_key]
|
||||||
|
|
||||||
|
result = await self._provider.embed_text(text)
|
||||||
|
|
||||||
|
# 存入缓存
|
||||||
|
if self.cache_enabled:
|
||||||
|
self._cache[cache_key] = result.embedding
|
||||||
|
|
||||||
|
return result.embedding
|
||||||
|
|
||||||
|
async def embed_batch(
|
||||||
|
self,
|
||||||
|
texts: List[str],
|
||||||
|
batch_size: int = 100,
|
||||||
|
show_progress: bool = False,
|
||||||
|
) -> List[List[float]]:
|
||||||
|
"""
|
||||||
|
批量嵌入文本
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: 文本列表
|
||||||
|
batch_size: 批次大小
|
||||||
|
show_progress: 是否显示进度
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
嵌入向量列表
|
||||||
|
"""
|
||||||
|
if not texts:
|
||||||
|
return []
|
||||||
|
|
||||||
|
embeddings = []
|
||||||
|
uncached_indices = []
|
||||||
|
uncached_texts = []
|
||||||
|
|
||||||
|
# 检查缓存
|
||||||
|
for i, text in enumerate(texts):
|
||||||
|
if not text or not text.strip():
|
||||||
|
embeddings.append([0.0] * self.dimension)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if self.cache_enabled:
|
||||||
|
cache_key = self._cache_key(text)
|
||||||
|
if cache_key in self._cache:
|
||||||
|
embeddings.append(self._cache[cache_key])
|
||||||
|
continue
|
||||||
|
|
||||||
|
embeddings.append(None) # 占位
|
||||||
|
uncached_indices.append(i)
|
||||||
|
uncached_texts.append(text)
|
||||||
|
|
||||||
|
# 批量处理未缓存的文本
|
||||||
|
if uncached_texts:
|
||||||
|
for i in range(0, len(uncached_texts), batch_size):
|
||||||
|
batch = uncached_texts[i:i + batch_size]
|
||||||
|
batch_indices = uncached_indices[i:i + batch_size]
|
||||||
|
|
||||||
|
try:
|
||||||
|
results = await self._provider.embed_texts(batch)
|
||||||
|
|
||||||
|
for idx, result in zip(batch_indices, results):
|
||||||
|
embeddings[idx] = result.embedding
|
||||||
|
|
||||||
|
# 存入缓存
|
||||||
|
if self.cache_enabled:
|
||||||
|
cache_key = self._cache_key(texts[idx])
|
||||||
|
self._cache[cache_key] = result.embedding
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Batch embedding error: {e}")
|
||||||
|
# 对失败的使用零向量
|
||||||
|
for idx in batch_indices:
|
||||||
|
if embeddings[idx] is None:
|
||||||
|
embeddings[idx] = [0.0] * self.dimension
|
||||||
|
|
||||||
|
# 添加小延迟避免限流
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
# 确保没有 None
|
||||||
|
return [e if e is not None else [0.0] * self.dimension for e in embeddings]
|
||||||
|
|
||||||
|
def clear_cache(self):
|
||||||
|
"""清空缓存"""
|
||||||
|
self._cache.clear()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cache_size(self) -> int:
|
||||||
|
"""缓存大小"""
|
||||||
|
return len(self._cache)
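
# --- Editorial usage sketch (not part of this commit) ----------------------
# Wiring the service above to a local Ollama instance and embedding a small
# batch. Provider, model and URL are illustrative values.
async def _demo_embedding_service() -> None:
    service = EmbeddingService(
        provider="ollama",
        model="nomic-embed-text",
        base_url="http://localhost:11434",
    )
    vectors = await service.embed_batch(["def login(user): ...", "SELECT * FROM users"])
    print(len(vectors), "vectors of dimension", service.dimension)
    print("cache entries:", service.cache_size)
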

@@ -0,0 +1,585 @@
"""
Code indexer
Splits code into chunks and indexes them into a vector database.
"""

import os
import asyncio
import logging
from typing import List, Dict, Any, Optional, AsyncGenerator, Callable
from pathlib import Path
from dataclasses import dataclass
import json

from .splitter import CodeSplitter, CodeChunk
from .embeddings import EmbeddingService

logger = logging.getLogger(__name__)


# Supported text file extensions
TEXT_EXTENSIONS = {
    ".py", ".js", ".ts", ".tsx", ".jsx", ".java", ".go", ".rs",
    ".cpp", ".c", ".h", ".cc", ".hh", ".cs", ".php", ".rb",
    ".kt", ".swift", ".sql", ".sh", ".json", ".yml", ".yaml",
    ".xml", ".html", ".css", ".vue", ".svelte", ".md",
}

# Directories excluded from indexing
EXCLUDE_DIRS = {
    "node_modules", "vendor", "dist", "build", ".git",
    "__pycache__", ".pytest_cache", "coverage", ".nyc_output",
    ".vscode", ".idea", ".vs", "target", "out", "bin", "obj",
    "__MACOSX", ".next", ".nuxt", "venv", "env", ".env",
}

# Files excluded from indexing
EXCLUDE_FILES = {
    ".DS_Store", "package-lock.json", "yarn.lock", "pnpm-lock.yaml",
    "Cargo.lock", "poetry.lock", "composer.lock", "Gemfile.lock",
}

@dataclass
class IndexingProgress:
    """Indexing progress."""
    total_files: int = 0
    processed_files: int = 0
    total_chunks: int = 0
    indexed_chunks: int = 0
    current_file: str = ""
    errors: List[str] = None

    def __post_init__(self):
        if self.errors is None:
            self.errors = []

    @property
    def progress_percentage(self) -> float:
        if self.total_files == 0:
            return 0.0
        return (self.processed_files / self.total_files) * 100


@dataclass
class IndexingResult:
    """Indexing result."""
    success: bool
    total_files: int
    indexed_files: int
    total_chunks: int
    errors: List[str]
    collection_name: str

class VectorStore:
    """Abstract base class for vector stores."""

    async def initialize(self):
        """Initialize the store."""
        pass

    async def add_documents(
        self,
        ids: List[str],
        embeddings: List[List[float]],
        documents: List[str],
        metadatas: List[Dict[str, Any]],
    ):
        """Add documents."""
        raise NotImplementedError

    async def query(
        self,
        query_embedding: List[float],
        n_results: int = 10,
        where: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Run a similarity query."""
        raise NotImplementedError

    async def delete_collection(self):
        """Delete the collection."""
        raise NotImplementedError

    async def get_count(self) -> int:
        """Return the number of stored documents."""
        raise NotImplementedError


class ChromaVectorStore(VectorStore):
    """Chroma-backed vector store."""

    def __init__(
        self,
        collection_name: str,
        persist_directory: Optional[str] = None,
    ):
        self.collection_name = collection_name
        self.persist_directory = persist_directory
        self._client = None
        self._collection = None

    async def initialize(self):
        """Initialize Chroma."""
        try:
            import chromadb
            from chromadb.config import Settings

            if self.persist_directory:
                self._client = chromadb.PersistentClient(
                    path=self.persist_directory,
                    settings=Settings(anonymized_telemetry=False),
                )
            else:
                self._client = chromadb.Client(
                    settings=Settings(anonymized_telemetry=False),
                )

            self._collection = self._client.get_or_create_collection(
                name=self.collection_name,
                metadata={"hnsw:space": "cosine"},
            )

            logger.info(f"Chroma collection '{self.collection_name}' initialized")

        except ImportError:
            raise ImportError("chromadb is required. Install with: pip install chromadb")

    async def add_documents(
        self,
        ids: List[str],
        embeddings: List[List[float]],
        documents: List[str],
        metadatas: List[Dict[str, Any]],
    ):
        """Add documents to Chroma."""
        if not ids:
            return

        # Chroma restricts metadata value types, so clean them first
        cleaned_metadatas = []
        for meta in metadatas:
            cleaned = {}
            for k, v in meta.items():
                if isinstance(v, (str, int, float, bool)):
                    cleaned[k] = v
                elif isinstance(v, list):
                    # serialize lists as JSON strings
                    cleaned[k] = json.dumps(v)
                elif v is not None:
                    cleaned[k] = str(v)
            cleaned_metadatas.append(cleaned)

        # Add in batches (Chroma has a per-call batch limit)
        batch_size = 500
        for i in range(0, len(ids), batch_size):
            batch_ids = ids[i:i + batch_size]
            batch_embeddings = embeddings[i:i + batch_size]
            batch_documents = documents[i:i + batch_size]
            batch_metadatas = cleaned_metadatas[i:i + batch_size]

            await asyncio.to_thread(
                self._collection.add,
                ids=batch_ids,
                embeddings=batch_embeddings,
                documents=batch_documents,
                metadatas=batch_metadatas,
            )

    async def query(
        self,
        query_embedding: List[float],
        n_results: int = 10,
        where: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Query Chroma."""
        result = await asyncio.to_thread(
            self._collection.query,
            query_embeddings=[query_embedding],
            n_results=n_results,
            where=where,
            include=["documents", "metadatas", "distances"],
        )

        return {
            "ids": result["ids"][0] if result["ids"] else [],
            "documents": result["documents"][0] if result["documents"] else [],
            "metadatas": result["metadatas"][0] if result["metadatas"] else [],
            "distances": result["distances"][0] if result["distances"] else [],
        }

    async def delete_collection(self):
        """Delete the collection."""
        if self._client and self._collection:
            await asyncio.to_thread(
                self._client.delete_collection,
                name=self.collection_name,
            )

    async def get_count(self) -> int:
        """Return the number of stored documents."""
        if self._collection:
            return await asyncio.to_thread(self._collection.count)
        return 0

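# Note on the metadata cleaning in ChromaVectorStore.add_documents above: Chroma only
# accepts scalar metadata values, so list fields are serialized to JSON strings.
# A small illustrative sketch (the field values here are made up):
_example_meta = {"file_path": "app/auth.py", "line_start": 10,
                 "security_indicators": ["sql_execute", "password_assign"]}
_example_cleaned = {k: (json.dumps(v) if isinstance(v, list) else v)
                    for k, v in _example_meta.items()}
# _example_cleaned["security_indicators"] == '["sql_execute", "password_assign"]'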
class InMemoryVectorStore(VectorStore):
    """In-memory vector store (for tests or small projects)."""

    def __init__(self, collection_name: str):
        self.collection_name = collection_name
        self._documents: Dict[str, Dict[str, Any]] = {}

    async def initialize(self):
        """Initialize the store."""
        logger.info(f"InMemory vector store '{self.collection_name}' initialized")

    async def add_documents(
        self,
        ids: List[str],
        embeddings: List[List[float]],
        documents: List[str],
        metadatas: List[Dict[str, Any]],
    ):
        """Add documents."""
        for id_, emb, doc, meta in zip(ids, embeddings, documents, metadatas):
            self._documents[id_] = {
                "embedding": emb,
                "document": doc,
                "metadata": meta,
            }

    async def query(
        self,
        query_embedding: List[float],
        n_results: int = 10,
        where: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Query using cosine similarity."""
        import math

        def cosine_similarity(a: List[float], b: List[float]) -> float:
            dot = sum(x * y for x, y in zip(a, b))
            norm_a = math.sqrt(sum(x * x for x in a))
            norm_b = math.sqrt(sum(x * x for x in b))
            if norm_a == 0 or norm_b == 0:
                return 0.0
            return dot / (norm_a * norm_b)

        results = []
        for id_, data in self._documents.items():
            # apply the filter conditions
            if where:
                match = True
                for k, v in where.items():
                    if data["metadata"].get(k) != v:
                        match = False
                        break
                if not match:
                    continue

            similarity = cosine_similarity(query_embedding, data["embedding"])
            results.append({
                "id": id_,
                "document": data["document"],
                "metadata": data["metadata"],
                "distance": 1 - similarity,  # convert similarity to a distance
            })

        # sort by distance and truncate
        results.sort(key=lambda x: x["distance"])
        results = results[:n_results]

        return {
            "ids": [r["id"] for r in results],
            "documents": [r["document"] for r in results],
            "metadatas": [r["metadata"] for r in results],
            "distances": [r["distance"] for r in results],
        }

    async def delete_collection(self):
        """Delete the collection."""
        self._documents.clear()

    async def get_count(self) -> int:
        """Return the number of stored documents."""
        return len(self._documents)

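# Hypothetical round trip with the in-memory store above (vectors shortened for clarity):
async def _demo_inmemory_store():
    store = InMemoryVectorStore(collection_name="demo")
    await store.initialize()
    await store.add_documents(
        ids=["a", "b"],
        embeddings=[[1.0, 0.0], [0.0, 1.0]],
        documents=["def foo(): ...", "def bar(): ..."],
        metadatas=[{"language": "python"}, {"language": "go"}],
    )
    hits = await store.query([1.0, 0.1], n_results=1, where={"language": "python"})
    print(hits["ids"], hits["distances"])   # ['a'] with a distance close to 0.0
# e.g. asyncio.run(_demo_inmemory_store())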
class CodeIndexer:
    """
    Code indexer
    Splits code files into chunks, embeds them, and indexes them into a vector database.
    """

    def __init__(
        self,
        collection_name: str,
        embedding_service: Optional[EmbeddingService] = None,
        vector_store: Optional[VectorStore] = None,
        splitter: Optional[CodeSplitter] = None,
        persist_directory: Optional[str] = None,
    ):
        """
        Initialize the indexer.

        Args:
            collection_name: Name of the vector collection
            embedding_service: Embedding service
            vector_store: Vector store
            splitter: Code splitter
            persist_directory: Persistence directory
        """
        self.collection_name = collection_name
        self.embedding_service = embedding_service or EmbeddingService()
        self.splitter = splitter or CodeSplitter()

        # Create the vector store
        if vector_store:
            self.vector_store = vector_store
        else:
            try:
                import chromadb  # noqa: F401  verify chromadb is importable so the fallback can trigger
                self.vector_store = ChromaVectorStore(
                    collection_name=collection_name,
                    persist_directory=persist_directory,
                )
            except ImportError:
                logger.warning("Chroma not available, using in-memory store")
                self.vector_store = InMemoryVectorStore(collection_name=collection_name)

        self._initialized = False

    async def initialize(self):
        """Initialize the indexer."""
        if not self._initialized:
            await self.vector_store.initialize()
            self._initialized = True

async def index_directory(
|
||||||
|
self,
|
||||||
|
directory: str,
|
||||||
|
exclude_patterns: Optional[List[str]] = None,
|
||||||
|
include_patterns: Optional[List[str]] = None,
|
||||||
|
progress_callback: Optional[Callable[[IndexingProgress], None]] = None,
|
||||||
|
) -> AsyncGenerator[IndexingProgress, None]:
|
||||||
|
"""
|
||||||
|
索引目录中的代码文件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
directory: 目录路径
|
||||||
|
exclude_patterns: 排除模式
|
||||||
|
include_patterns: 包含模式
|
||||||
|
progress_callback: 进度回调
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
索引进度
|
||||||
|
"""
|
||||||
|
await self.initialize()
|
||||||
|
|
||||||
|
progress = IndexingProgress()
|
||||||
|
exclude_patterns = exclude_patterns or []
|
||||||
|
|
||||||
|
# 收集文件
|
||||||
|
files = self._collect_files(directory, exclude_patterns, include_patterns)
|
||||||
|
progress.total_files = len(files)
|
||||||
|
|
||||||
|
logger.info(f"Found {len(files)} files to index in {directory}")
|
||||||
|
yield progress
|
||||||
|
|
||||||
|
all_chunks: List[CodeChunk] = []
|
||||||
|
|
||||||
|
# 分块处理文件
|
||||||
|
for file_path in files:
|
||||||
|
progress.current_file = file_path
|
||||||
|
|
||||||
|
try:
|
||||||
|
relative_path = os.path.relpath(file_path, directory)
|
||||||
|
|
||||||
|
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
|
||||||
|
content = f.read()
|
||||||
|
|
||||||
|
if not content.strip():
|
||||||
|
progress.processed_files += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 限制文件大小
|
||||||
|
if len(content) > 500000: # 500KB
|
||||||
|
content = content[:500000]
|
||||||
|
|
||||||
|
# 分块
|
||||||
|
chunks = self.splitter.split_file(content, relative_path)
|
||||||
|
all_chunks.extend(chunks)
|
||||||
|
|
||||||
|
progress.processed_files += 1
|
||||||
|
progress.total_chunks = len(all_chunks)
|
||||||
|
|
||||||
|
if progress_callback:
|
||||||
|
progress_callback(progress)
|
||||||
|
yield progress
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error processing {file_path}: {e}")
|
||||||
|
progress.errors.append(f"{file_path}: {str(e)}")
|
||||||
|
progress.processed_files += 1
|
||||||
|
|
||||||
|
logger.info(f"Created {len(all_chunks)} chunks from {len(files)} files")
|
||||||
|
|
||||||
|
# 批量嵌入和索引
|
||||||
|
if all_chunks:
|
||||||
|
await self._index_chunks(all_chunks, progress)
|
||||||
|
|
||||||
|
progress.indexed_chunks = len(all_chunks)
|
||||||
|
yield progress
|
||||||
|
|
||||||
|
async def index_files(
|
||||||
|
self,
|
||||||
|
files: List[Dict[str, str]],
|
||||||
|
base_path: str = "",
|
||||||
|
progress_callback: Optional[Callable[[IndexingProgress], None]] = None,
|
||||||
|
) -> AsyncGenerator[IndexingProgress, None]:
|
||||||
|
"""
|
||||||
|
索引文件列表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
files: 文件列表 [{"path": "...", "content": "..."}]
|
||||||
|
base_path: 基础路径
|
||||||
|
progress_callback: 进度回调
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
索引进度
|
||||||
|
"""
|
||||||
|
await self.initialize()
|
||||||
|
|
||||||
|
progress = IndexingProgress()
|
||||||
|
progress.total_files = len(files)
|
||||||
|
|
||||||
|
all_chunks: List[CodeChunk] = []
|
||||||
|
|
||||||
|
for file_info in files:
|
||||||
|
file_path = file_info.get("path", "")
|
||||||
|
content = file_info.get("content", "")
|
||||||
|
|
||||||
|
progress.current_file = file_path
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not content.strip():
|
||||||
|
progress.processed_files += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 限制文件大小
|
||||||
|
if len(content) > 500000:
|
||||||
|
content = content[:500000]
|
||||||
|
|
||||||
|
# 分块
|
||||||
|
chunks = self.splitter.split_file(content, file_path)
|
||||||
|
all_chunks.extend(chunks)
|
||||||
|
|
||||||
|
progress.processed_files += 1
|
||||||
|
progress.total_chunks = len(all_chunks)
|
||||||
|
|
||||||
|
if progress_callback:
|
||||||
|
progress_callback(progress)
|
||||||
|
yield progress
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error processing {file_path}: {e}")
|
||||||
|
progress.errors.append(f"{file_path}: {str(e)}")
|
||||||
|
progress.processed_files += 1
|
||||||
|
|
||||||
|
# 批量嵌入和索引
|
||||||
|
if all_chunks:
|
||||||
|
await self._index_chunks(all_chunks, progress)
|
||||||
|
|
||||||
|
progress.indexed_chunks = len(all_chunks)
|
||||||
|
yield progress
|
||||||
|
|
||||||
|
async def _index_chunks(self, chunks: List[CodeChunk], progress: IndexingProgress):
|
||||||
|
"""索引代码块"""
|
||||||
|
# 准备嵌入文本
|
||||||
|
texts = [chunk.to_embedding_text() for chunk in chunks]
|
||||||
|
|
||||||
|
logger.info(f"Generating embeddings for {len(texts)} chunks...")
|
||||||
|
|
||||||
|
# 批量嵌入
|
||||||
|
embeddings = await self.embedding_service.embed_batch(texts, batch_size=50)
|
||||||
|
|
||||||
|
# 准备元数据
|
||||||
|
ids = [chunk.id for chunk in chunks]
|
||||||
|
documents = [chunk.content for chunk in chunks]
|
||||||
|
metadatas = [chunk.to_dict() for chunk in chunks]
|
||||||
|
|
||||||
|
# 添加到向量存储
|
||||||
|
logger.info(f"Adding {len(chunks)} chunks to vector store...")
|
||||||
|
await self.vector_store.add_documents(
|
||||||
|
ids=ids,
|
||||||
|
embeddings=embeddings,
|
||||||
|
documents=documents,
|
||||||
|
metadatas=metadatas,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Indexed {len(chunks)} chunks successfully")
|
||||||
|
|
||||||
|
def _collect_files(
|
||||||
|
self,
|
||||||
|
directory: str,
|
||||||
|
exclude_patterns: List[str],
|
||||||
|
include_patterns: Optional[List[str]],
|
||||||
|
) -> List[str]:
|
||||||
|
"""收集需要索引的文件"""
|
||||||
|
import fnmatch
|
||||||
|
|
||||||
|
files = []
|
||||||
|
|
||||||
|
for root, dirs, filenames in os.walk(directory):
|
||||||
|
# 过滤目录
|
||||||
|
dirs[:] = [d for d in dirs if d not in EXCLUDE_DIRS]
|
||||||
|
|
||||||
|
for filename in filenames:
|
||||||
|
# 检查扩展名
|
||||||
|
ext = os.path.splitext(filename)[1].lower()
|
||||||
|
if ext not in TEXT_EXTENSIONS:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 检查排除文件
|
||||||
|
if filename in EXCLUDE_FILES:
|
||||||
|
continue
|
||||||
|
|
||||||
|
file_path = os.path.join(root, filename)
|
||||||
|
relative_path = os.path.relpath(file_path, directory)
|
||||||
|
|
||||||
|
# 检查排除模式
|
||||||
|
excluded = False
|
||||||
|
for pattern in exclude_patterns:
|
||||||
|
if fnmatch.fnmatch(relative_path, pattern) or fnmatch.fnmatch(filename, pattern):
|
||||||
|
excluded = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if excluded:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 检查包含模式
|
||||||
|
if include_patterns:
|
||||||
|
included = False
|
||||||
|
for pattern in include_patterns:
|
||||||
|
if fnmatch.fnmatch(relative_path, pattern) or fnmatch.fnmatch(filename, pattern):
|
||||||
|
included = True
|
||||||
|
break
|
||||||
|
if not included:
|
||||||
|
continue
|
||||||
|
|
||||||
|
files.append(file_path)
|
||||||
|
|
||||||
|
return files
|
||||||
|
|
||||||
|
async def get_chunk_count(self) -> int:
|
||||||
|
"""获取已索引的代码块数量"""
|
||||||
|
await self.initialize()
|
||||||
|
return await self.vector_store.get_count()
|
||||||
|
|
||||||
|
async def clear(self):
|
||||||
|
"""清空索引"""
|
||||||
|
await self.initialize()
|
||||||
|
await self.vector_store.delete_collection()
|
||||||
|
self._initialized = False
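

# Hypothetical driver for the indexer above; the repository path and collection name are
# placeholders, and index_directory is an async generator of IndexingProgress.
async def _demo_index_repo():
    indexer = CodeIndexer(collection_name="project-demo", persist_directory="./chroma_db")
    async for progress in indexer.index_directory("/path/to/repo"):
        print(f"{progress.progress_percentage:.0f}% "
              f"({progress.processed_files}/{progress.total_files} files, "
              f"{progress.total_chunks} chunks)")
    print("indexed chunks:", await indexer.get_chunk_count())
# e.g. asyncio.run(_demo_index_repo())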

@@ -0,0 +1,469 @@
"""
Code retriever
Supports semantic and hybrid retrieval.
"""

import re
import logging
from typing import List, Dict, Any, Optional
from dataclasses import dataclass, field

from .embeddings import EmbeddingService
from .indexer import VectorStore, ChromaVectorStore, InMemoryVectorStore
from .splitter import CodeChunk, ChunkType

logger = logging.getLogger(__name__)


@dataclass
class RetrievalResult:
    """A single retrieval result."""
    chunk_id: str
    content: str
    file_path: str
    language: str
    chunk_type: str
    line_start: int
    line_end: int
    score: float  # similarity score (0-1, higher means more similar)

    # optional metadata
    name: Optional[str] = None
    parent_name: Optional[str] = None
    signature: Optional[str] = None
    security_indicators: List[str] = field(default_factory=list)

    # raw metadata
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        return {
            "chunk_id": self.chunk_id,
            "content": self.content,
            "file_path": self.file_path,
            "language": self.language,
            "chunk_type": self.chunk_type,
            "line_start": self.line_start,
            "line_end": self.line_end,
            "score": self.score,
            "name": self.name,
            "parent_name": self.parent_name,
            "signature": self.signature,
            "security_indicators": self.security_indicators,
        }

    def to_context_string(self, include_metadata: bool = True) -> str:
        """Format the result as a context string (for LLM input)."""
        parts = []

        if include_metadata:
            header = f"File: {self.file_path}"
            if self.line_start and self.line_end:
                header += f" (lines {self.line_start}-{self.line_end})"
            if self.name:
                header += f"\n{self.chunk_type.title()}: {self.name}"
                if self.parent_name:
                    header += f" in {self.parent_name}"
            parts.append(header)

        parts.append(f"```{self.language}\n{self.content}\n```")

        return "\n".join(parts)

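# Illustration of the context string produced above (the field values are made up):
_example_result = RetrievalResult(
    chunk_id="abc123", content="def login(user, pwd):\n    ...",
    file_path="app/auth.py", language="python", chunk_type="function",
    line_start=10, line_end=14, score=0.87, name="login",
)
# _example_result.to_context_string() yields:
# File: app/auth.py (lines 10-14)
# Function: login
# ```python
# def login(user, pwd):
#     ...
# ```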
class CodeRetriever:
|
||||||
|
"""
|
||||||
|
代码检索器
|
||||||
|
支持语义检索、关键字检索和混合检索
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
collection_name: str,
|
||||||
|
embedding_service: Optional[EmbeddingService] = None,
|
||||||
|
vector_store: Optional[VectorStore] = None,
|
||||||
|
persist_directory: Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
初始化检索器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
collection_name: 向量集合名称
|
||||||
|
embedding_service: 嵌入服务
|
||||||
|
vector_store: 向量存储
|
||||||
|
persist_directory: 持久化目录
|
||||||
|
"""
|
||||||
|
self.collection_name = collection_name
|
||||||
|
self.embedding_service = embedding_service or EmbeddingService()
|
||||||
|
|
||||||
|
# 创建向量存储
|
||||||
|
if vector_store:
|
||||||
|
self.vector_store = vector_store
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
self.vector_store = ChromaVectorStore(
|
||||||
|
collection_name=collection_name,
|
||||||
|
persist_directory=persist_directory,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
logger.warning("Chroma not available, using in-memory store")
|
||||||
|
self.vector_store = InMemoryVectorStore(collection_name=collection_name)
|
||||||
|
|
||||||
|
self._initialized = False
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
"""初始化检索器"""
|
||||||
|
if not self._initialized:
|
||||||
|
await self.vector_store.initialize()
|
||||||
|
self._initialized = True
|
||||||
|
|
||||||
|
async def retrieve(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
top_k: int = 10,
|
||||||
|
filter_file_path: Optional[str] = None,
|
||||||
|
filter_language: Optional[str] = None,
|
||||||
|
filter_chunk_type: Optional[str] = None,
|
||||||
|
min_score: float = 0.0,
|
||||||
|
) -> List[RetrievalResult]:
|
||||||
|
"""
|
||||||
|
语义检索
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 查询文本
|
||||||
|
top_k: 返回数量
|
||||||
|
filter_file_path: 文件路径过滤
|
||||||
|
filter_language: 语言过滤
|
||||||
|
filter_chunk_type: 块类型过滤
|
||||||
|
min_score: 最小相似度分数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
检索结果列表
|
||||||
|
"""
|
||||||
|
await self.initialize()
|
||||||
|
|
||||||
|
# 生成查询嵌入
|
||||||
|
query_embedding = await self.embedding_service.embed(query)
|
||||||
|
|
||||||
|
# 构建过滤条件
|
||||||
|
where = {}
|
||||||
|
if filter_file_path:
|
||||||
|
where["file_path"] = filter_file_path
|
||||||
|
if filter_language:
|
||||||
|
where["language"] = filter_language
|
||||||
|
if filter_chunk_type:
|
||||||
|
where["chunk_type"] = filter_chunk_type
|
||||||
|
|
||||||
|
# 查询向量存储
|
||||||
|
raw_results = await self.vector_store.query(
|
||||||
|
query_embedding=query_embedding,
|
||||||
|
n_results=top_k * 2, # 多查一些,后面过滤
|
||||||
|
where=where if where else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 转换结果
|
||||||
|
results = []
|
||||||
|
for i, (id_, doc, meta, dist) in enumerate(zip(
|
||||||
|
raw_results["ids"],
|
||||||
|
raw_results["documents"],
|
||||||
|
raw_results["metadatas"],
|
||||||
|
raw_results["distances"],
|
||||||
|
)):
|
||||||
|
# 将距离转换为相似度分数 (余弦距离)
|
||||||
|
score = 1 - dist
|
||||||
|
|
||||||
|
if score < min_score:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 解析安全指标(可能是 JSON 字符串)
|
||||||
|
security_indicators = meta.get("security_indicators", [])
|
||||||
|
if isinstance(security_indicators, str):
|
||||||
|
try:
|
||||||
|
import json
|
||||||
|
security_indicators = json.loads(security_indicators)
|
||||||
|
except Exception:
|
||||||
|
security_indicators = []
|
||||||
|
|
||||||
|
result = RetrievalResult(
|
||||||
|
chunk_id=id_,
|
||||||
|
content=doc,
|
||||||
|
file_path=meta.get("file_path", ""),
|
||||||
|
language=meta.get("language", "text"),
|
||||||
|
chunk_type=meta.get("chunk_type", "unknown"),
|
||||||
|
line_start=meta.get("line_start", 0),
|
||||||
|
line_end=meta.get("line_end", 0),
|
||||||
|
score=score,
|
||||||
|
name=meta.get("name"),
|
||||||
|
parent_name=meta.get("parent_name"),
|
||||||
|
signature=meta.get("signature"),
|
||||||
|
security_indicators=security_indicators,
|
||||||
|
metadata=meta,
|
||||||
|
)
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
|
# 按分数排序并截取
|
||||||
|
results.sort(key=lambda x: x.score, reverse=True)
|
||||||
|
return results[:top_k]
|
||||||
|
|
||||||
|
async def retrieve_by_file(
|
||||||
|
self,
|
||||||
|
file_path: str,
|
||||||
|
top_k: int = 50,
|
||||||
|
) -> List[RetrievalResult]:
|
||||||
|
"""
|
||||||
|
按文件路径检索
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: 文件路径
|
||||||
|
top_k: 返回数量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
该文件的所有代码块
|
||||||
|
"""
|
||||||
|
await self.initialize()
|
||||||
|
|
||||||
|
# 使用一个通用查询
|
||||||
|
query_embedding = await self.embedding_service.embed(f"code in {file_path}")
|
||||||
|
|
||||||
|
raw_results = await self.vector_store.query(
|
||||||
|
query_embedding=query_embedding,
|
||||||
|
n_results=top_k,
|
||||||
|
where={"file_path": file_path},
|
||||||
|
)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for id_, doc, meta, dist in zip(
|
||||||
|
raw_results["ids"],
|
||||||
|
raw_results["documents"],
|
||||||
|
raw_results["metadatas"],
|
||||||
|
raw_results["distances"],
|
||||||
|
):
|
||||||
|
result = RetrievalResult(
|
||||||
|
chunk_id=id_,
|
||||||
|
content=doc,
|
||||||
|
file_path=meta.get("file_path", ""),
|
||||||
|
language=meta.get("language", "text"),
|
||||||
|
chunk_type=meta.get("chunk_type", "unknown"),
|
||||||
|
line_start=meta.get("line_start", 0),
|
||||||
|
line_end=meta.get("line_end", 0),
|
||||||
|
score=1 - dist,
|
||||||
|
name=meta.get("name"),
|
||||||
|
parent_name=meta.get("parent_name"),
|
||||||
|
metadata=meta,
|
||||||
|
)
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
|
# 按行号排序
|
||||||
|
results.sort(key=lambda x: x.line_start)
|
||||||
|
return results
|
||||||
|
|
||||||
|
async def retrieve_security_related(
|
||||||
|
self,
|
||||||
|
vulnerability_type: Optional[str] = None,
|
||||||
|
top_k: int = 20,
|
||||||
|
) -> List[RetrievalResult]:
|
||||||
|
"""
|
||||||
|
检索与安全相关的代码
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vulnerability_type: 漏洞类型(如 sql_injection, xss 等)
|
||||||
|
top_k: 返回数量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
安全相关的代码块
|
||||||
|
"""
|
||||||
|
# 根据漏洞类型构建查询
|
||||||
|
security_queries = {
|
||||||
|
"sql_injection": "SQL query execute database user input",
|
||||||
|
"xss": "HTML render user input innerHTML template",
|
||||||
|
"command_injection": "system exec command shell subprocess",
|
||||||
|
"path_traversal": "file path read open user input",
|
||||||
|
"ssrf": "HTTP request URL user input fetch",
|
||||||
|
"deserialization": "deserialize pickle yaml load object",
|
||||||
|
"auth_bypass": "authentication login password token session",
|
||||||
|
"hardcoded_secret": "password secret key token credential",
|
||||||
|
}
|
||||||
|
|
||||||
|
if vulnerability_type and vulnerability_type in security_queries:
|
||||||
|
query = security_queries[vulnerability_type]
|
||||||
|
else:
|
||||||
|
query = "security vulnerability dangerous function user input"
|
||||||
|
|
||||||
|
return await self.retrieve(query, top_k=top_k)
|
||||||
|
|
||||||
|
async def retrieve_function_context(
|
||||||
|
self,
|
||||||
|
function_name: str,
|
||||||
|
file_path: Optional[str] = None,
|
||||||
|
include_callers: bool = True,
|
||||||
|
include_callees: bool = True,
|
||||||
|
top_k: int = 10,
|
||||||
|
) -> Dict[str, List[RetrievalResult]]:
|
||||||
|
"""
|
||||||
|
检索函数上下文
|
||||||
|
|
||||||
|
Args:
|
||||||
|
function_name: 函数名
|
||||||
|
file_path: 文件路径(可选)
|
||||||
|
include_callers: 是否包含调用者
|
||||||
|
include_callees: 是否包含被调用者
|
||||||
|
top_k: 每类返回数量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含函数定义、调用者、被调用者的字典
|
||||||
|
"""
|
||||||
|
context = {
|
||||||
|
"definition": [],
|
||||||
|
"callers": [],
|
||||||
|
"callees": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
# 查找函数定义
|
||||||
|
definition_query = f"function definition {function_name}"
|
||||||
|
definitions = await self.retrieve(
|
||||||
|
definition_query,
|
||||||
|
top_k=5,
|
||||||
|
filter_file_path=file_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 过滤出真正的定义
|
||||||
|
for result in definitions:
|
||||||
|
if result.name == function_name or function_name in (result.content or ""):
|
||||||
|
context["definition"].append(result)
|
||||||
|
|
||||||
|
if include_callers:
|
||||||
|
# 查找调用此函数的代码
|
||||||
|
caller_query = f"calls {function_name} invoke {function_name}"
|
||||||
|
callers = await self.retrieve(caller_query, top_k=top_k)
|
||||||
|
|
||||||
|
for result in callers:
|
||||||
|
# 检查是否真的调用了这个函数
|
||||||
|
if re.search(rf'\b{re.escape(function_name)}\s*\(', result.content):
|
||||||
|
if result not in context["definition"]:
|
||||||
|
context["callers"].append(result)
|
||||||
|
|
||||||
|
if include_callees and context["definition"]:
|
||||||
|
# 从函数定义中提取调用的其他函数
|
||||||
|
for definition in context["definition"]:
|
||||||
|
calls = re.findall(r'\b(\w+)\s*\(', definition.content)
|
||||||
|
unique_calls = list(set(calls))[:5] # 限制数量
|
||||||
|
|
||||||
|
for call in unique_calls:
|
||||||
|
if call == function_name:
|
||||||
|
continue
|
||||||
|
callees = await self.retrieve(
|
||||||
|
f"function {call} definition",
|
||||||
|
top_k=2,
|
||||||
|
)
|
||||||
|
context["callees"].extend(callees)
|
||||||
|
|
||||||
|
return context
|
||||||
|
|
||||||
|
async def retrieve_similar_code(
|
||||||
|
self,
|
||||||
|
code_snippet: str,
|
||||||
|
top_k: int = 5,
|
||||||
|
exclude_file: Optional[str] = None,
|
||||||
|
) -> List[RetrievalResult]:
|
||||||
|
"""
|
||||||
|
检索相似的代码
|
||||||
|
|
||||||
|
Args:
|
||||||
|
code_snippet: 代码片段
|
||||||
|
top_k: 返回数量
|
||||||
|
exclude_file: 排除的文件
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
相似代码列表
|
||||||
|
"""
|
||||||
|
results = await self.retrieve(
|
||||||
|
f"similar code: {code_snippet}",
|
||||||
|
top_k=top_k * 2,
|
||||||
|
)
|
||||||
|
|
||||||
|
if exclude_file:
|
||||||
|
results = [r for r in results if r.file_path != exclude_file]
|
||||||
|
|
||||||
|
return results[:top_k]

    async def hybrid_retrieve(
        self,
        query: str,
        keywords: Optional[List[str]] = None,
        top_k: int = 10,
        semantic_weight: float = 0.7,
    ) -> List[RetrievalResult]:
        """
        Hybrid retrieval (semantic + keyword).

        Args:
            query: Query text
            keywords: Extra keywords
            top_k: Number of results to return
            semantic_weight: Weight given to the semantic score

        Returns:
            List of retrieval results
        """
        # semantic retrieval
        semantic_results = await self.retrieve(query, top_k=top_k * 2)

        # if keywords are given, filter/boost by keyword matches
        if keywords:
            keyword_pattern = '|'.join(re.escape(kw) for kw in keywords)

            enhanced_results = []
            for result in semantic_results:
                # keyword match ratio
                matches = len(re.findall(keyword_pattern, result.content, re.IGNORECASE))
                keyword_score = min(1.0, matches / len(keywords))

                # blended score
                hybrid_score = (
                    semantic_weight * result.score +
                    (1 - semantic_weight) * keyword_score
                )

                result.score = hybrid_score
                enhanced_results.append(result)

            enhanced_results.sort(key=lambda x: x.score, reverse=True)
            return enhanced_results[:top_k]

        return semantic_results[:top_k]

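    # Worked example of the blended score above, with illustrative numbers:
    # with semantic_weight = 0.7, a chunk whose semantic score is 0.80 and which matches
    # 2 of 3 supplied keywords (keyword_score = min(1.0, 2/3) ≈ 0.67) ends up with
    # hybrid_score = 0.7 * 0.80 + 0.3 * 0.67 ≈ 0.76.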
    def format_results_for_llm(
        self,
        results: List[RetrievalResult],
        max_tokens: int = 4000,
        include_metadata: bool = True,
    ) -> str:
        """
        Format retrieval results as LLM input.

        Args:
            results: Retrieval results
            max_tokens: Maximum number of tokens
            include_metadata: Whether to include metadata

        Returns:
            The formatted string
        """
        if not results:
            return "No relevant code found."

        parts = []
        total_tokens = 0

        for i, result in enumerate(results):
            context = result.to_context_string(include_metadata=include_metadata)
            estimated_tokens = len(context) // 4

            if total_tokens + estimated_tokens > max_tokens:
                break

            parts.append(f"### Code Block {i + 1} (Score: {result.score:.2f})\n{context}")
            total_tokens += estimated_tokens

        return "\n\n".join(parts)

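# Hypothetical end-to-end retrieval sketch; the collection name, query, and keywords
# are placeholders, not values taken from this commit.
async def _demo_retrieve():
    retriever = CodeRetriever(collection_name="project-demo", persist_directory="./chroma_db")
    results = await retriever.hybrid_retrieve(
        "raw SQL query built from user input",
        keywords=["execute", "cursor", "%s"],
        top_k=5,
    )
    prompt_context = retriever.format_results_for_llm(results, max_tokens=3000)
    print(prompt_context)
# e.g. asyncio.run(_demo_retrieve())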

@@ -0,0 +1,785 @@
"""
Code splitter - intelligent, Tree-sitter-AST-based code chunking
Uses advanced Python libraries for professional-grade code parsing.
"""

import re
import hashlib
import logging
from typing import List, Dict, Any, Optional, Tuple, Set
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path

logger = logging.getLogger(__name__)


class ChunkType(Enum):
    """Code chunk type."""
    FILE = "file"
    MODULE = "module"
    CLASS = "class"
    FUNCTION = "function"
    METHOD = "method"
    INTERFACE = "interface"
    STRUCT = "struct"
    ENUM = "enum"
    IMPORT = "import"
    CONSTANT = "constant"
    CONFIG = "config"
    COMMENT = "comment"
    DECORATOR = "decorator"
    UNKNOWN = "unknown"

@dataclass
class CodeChunk:
    """A code chunk."""
    id: str
    content: str
    file_path: str
    language: str
    chunk_type: ChunkType

    # position info
    line_start: int = 0
    line_end: int = 0
    byte_start: int = 0
    byte_end: int = 0

    # semantic info
    name: Optional[str] = None
    parent_name: Optional[str] = None
    signature: Optional[str] = None
    docstring: Optional[str] = None

    # AST info
    ast_type: Optional[str] = None

    # relationship info
    imports: List[str] = field(default_factory=list)
    calls: List[str] = field(default_factory=list)
    dependencies: List[str] = field(default_factory=list)
    definitions: List[str] = field(default_factory=list)

    # security-related indicators
    security_indicators: List[str] = field(default_factory=list)

    # extra metadata
    metadata: Dict[str, Any] = field(default_factory=dict)

    # token estimate
    estimated_tokens: int = 0

    def __post_init__(self):
        if not self.id:
            self.id = self._generate_id()
        if not self.estimated_tokens:
            self.estimated_tokens = self._estimate_tokens()

    def _generate_id(self) -> str:
        content = f"{self.file_path}:{self.line_start}:{self.line_end}:{self.content[:100]}"
        return hashlib.sha256(content.encode()).hexdigest()[:16]

    def _estimate_tokens(self) -> int:
        # use tiktoken if it is available
        try:
            import tiktoken
            enc = tiktoken.get_encoding("cl100k_base")
            return len(enc.encode(self.content))
        except ImportError:
            return len(self.content) // 4

    def to_dict(self) -> Dict[str, Any]:
        return {
            "id": self.id,
            "content": self.content,
            "file_path": self.file_path,
            "language": self.language,
            "chunk_type": self.chunk_type.value,
            "line_start": self.line_start,
            "line_end": self.line_end,
            "name": self.name,
            "parent_name": self.parent_name,
            "signature": self.signature,
            "docstring": self.docstring,
            "ast_type": self.ast_type,
            "imports": self.imports,
            "calls": self.calls,
            "definitions": self.definitions,
            "security_indicators": self.security_indicators,
            "estimated_tokens": self.estimated_tokens,
            "metadata": self.metadata,
        }

    def to_embedding_text(self) -> str:
        """Build the text that gets embedded."""
        parts = []
        parts.append(f"File: {self.file_path}")
        if self.name:
            parts.append(f"{self.chunk_type.value.title()}: {self.name}")
        if self.parent_name:
            parts.append(f"In: {self.parent_name}")
        if self.signature:
            parts.append(f"Signature: {self.signature}")
        if self.docstring:
            parts.append(f"Description: {self.docstring[:300]}")
        parts.append(f"Code:\n{self.content}")
        return "\n".join(parts)

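# Illustration of the embedding text built above (the field values are made up):
_example_chunk = CodeChunk(
    id="", content="def check_token(token):\n    return token == SECRET",
    file_path="app/auth.py", language="python",
    chunk_type=ChunkType.FUNCTION, line_start=42, line_end=43, name="check_token",
)
# _example_chunk.to_embedding_text() yields:
# File: app/auth.py
# Function: check_token
# Code:
# def check_token(token):
#     return token == SECRET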
class TreeSitterParser:
|
||||||
|
"""
|
||||||
|
基于 Tree-sitter 的代码解析器
|
||||||
|
提供 AST 级别的代码分析
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 语言映射
|
||||||
|
LANGUAGE_MAP = {
|
||||||
|
".py": "python",
|
||||||
|
".js": "javascript",
|
||||||
|
".jsx": "javascript",
|
||||||
|
".ts": "typescript",
|
||||||
|
".tsx": "tsx",
|
||||||
|
".java": "java",
|
||||||
|
".go": "go",
|
||||||
|
".rs": "rust",
|
||||||
|
".cpp": "cpp",
|
||||||
|
".c": "c",
|
||||||
|
".h": "c",
|
||||||
|
".hpp": "cpp",
|
||||||
|
".cs": "c_sharp",
|
||||||
|
".php": "php",
|
||||||
|
".rb": "ruby",
|
||||||
|
".kt": "kotlin",
|
||||||
|
".swift": "swift",
|
||||||
|
}
|
||||||
|
|
||||||
|
# 各语言的函数/类节点类型
|
||||||
|
DEFINITION_TYPES = {
|
||||||
|
"python": {
|
||||||
|
"class": ["class_definition"],
|
||||||
|
"function": ["function_definition"],
|
||||||
|
"method": ["function_definition"],
|
||||||
|
"import": ["import_statement", "import_from_statement"],
|
||||||
|
},
|
||||||
|
"javascript": {
|
||||||
|
"class": ["class_declaration", "class"],
|
||||||
|
"function": ["function_declaration", "function", "arrow_function", "method_definition"],
|
||||||
|
"import": ["import_statement"],
|
||||||
|
},
|
||||||
|
"typescript": {
|
||||||
|
"class": ["class_declaration", "class"],
|
||||||
|
"function": ["function_declaration", "function", "arrow_function", "method_definition"],
|
||||||
|
"interface": ["interface_declaration"],
|
||||||
|
"import": ["import_statement"],
|
||||||
|
},
|
||||||
|
"java": {
|
||||||
|
"class": ["class_declaration"],
|
||||||
|
"method": ["method_declaration", "constructor_declaration"],
|
||||||
|
"interface": ["interface_declaration"],
|
||||||
|
"import": ["import_declaration"],
|
||||||
|
},
|
||||||
|
"go": {
|
||||||
|
"struct": ["type_declaration"],
|
||||||
|
"function": ["function_declaration", "method_declaration"],
|
||||||
|
"interface": ["type_declaration"],
|
||||||
|
"import": ["import_declaration"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._parsers: Dict[str, Any] = {}
|
||||||
|
self._initialized = False
|
||||||
|
|
||||||
|
def _ensure_initialized(self, language: str) -> bool:
|
||||||
|
"""确保语言解析器已初始化"""
|
||||||
|
if language in self._parsers:
|
||||||
|
return True
|
||||||
|
|
||||||
|
try:
|
||||||
|
from tree_sitter_languages import get_parser, get_language
|
||||||
|
|
||||||
|
parser = get_parser(language)
|
||||||
|
self._parsers[language] = parser
|
||||||
|
return True
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
logger.warning("tree-sitter-languages not installed, falling back to regex parsing")
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to load tree-sitter parser for {language}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def parse(self, code: str, language: str) -> Optional[Any]:
|
||||||
|
"""解析代码返回 AST"""
|
||||||
|
if not self._ensure_initialized(language):
|
||||||
|
return None
|
||||||
|
|
||||||
|
parser = self._parsers.get(language)
|
||||||
|
if not parser:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
tree = parser.parse(code.encode())
|
||||||
|
return tree
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to parse code: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def extract_definitions(self, tree: Any, code: str, language: str) -> List[Dict[str, Any]]:
|
||||||
|
"""从 AST 提取定义"""
|
||||||
|
if tree is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
definitions = []
|
||||||
|
definition_types = self.DEFINITION_TYPES.get(language, {})
|
||||||
|
|
||||||
|
def traverse(node, parent_name=None):
|
||||||
|
node_type = node.type
|
||||||
|
|
||||||
|
# 检查是否是定义节点
|
||||||
|
for def_category, types in definition_types.items():
|
||||||
|
if node_type in types:
|
||||||
|
name = self._extract_name(node, language)
|
||||||
|
|
||||||
|
definitions.append({
|
||||||
|
"type": def_category,
|
||||||
|
"name": name,
|
||||||
|
"parent_name": parent_name,
|
||||||
|
"start_point": node.start_point,
|
||||||
|
"end_point": node.end_point,
|
||||||
|
"start_byte": node.start_byte,
|
||||||
|
"end_byte": node.end_byte,
|
||||||
|
"node_type": node_type,
|
||||||
|
})
|
||||||
|
|
||||||
|
# 对于类,继续遍历子节点找方法
|
||||||
|
if def_category == "class":
|
||||||
|
for child in node.children:
|
||||||
|
traverse(child, name)
|
||||||
|
return
|
||||||
|
|
||||||
|
# 继续遍历子节点
|
||||||
|
for child in node.children:
|
||||||
|
traverse(child, parent_name)
|
||||||
|
|
||||||
|
traverse(tree.root_node)
|
||||||
|
return definitions
|
||||||
|
|
||||||
|
def _extract_name(self, node: Any, language: str) -> Optional[str]:
|
||||||
|
"""从节点提取名称"""
|
||||||
|
# 查找 identifier 子节点
|
||||||
|
for child in node.children:
|
||||||
|
if child.type in ["identifier", "name", "type_identifier", "property_identifier"]:
|
||||||
|
return child.text.decode() if isinstance(child.text, bytes) else child.text
|
||||||
|
|
||||||
|
# 对于某些语言的特殊处理
|
||||||
|
if language == "python":
|
||||||
|
for child in node.children:
|
||||||
|
if child.type == "name":
|
||||||
|
return child.text.decode() if isinstance(child.text, bytes) else child.text
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class CodeSplitter:
|
||||||
|
"""
|
||||||
|
高级代码分块器
|
||||||
|
使用 Tree-sitter 进行 AST 解析,支持多种编程语言
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 危险函数/模式(用于安全指标)
|
||||||
|
SECURITY_PATTERNS = {
|
||||||
|
"python": [
|
||||||
|
(r"\bexec\s*\(", "exec"),
|
||||||
|
(r"\beval\s*\(", "eval"),
|
||||||
|
(r"\bcompile\s*\(", "compile"),
|
||||||
|
(r"\bos\.system\s*\(", "os_system"),
|
||||||
|
(r"\bsubprocess\.", "subprocess"),
|
||||||
|
(r"\bcursor\.execute\s*\(", "sql_execute"),
|
||||||
|
(r"\.execute\s*\(.*%", "sql_format"),
|
||||||
|
(r"\bpickle\.loads?\s*\(", "pickle"),
|
||||||
|
(r"\byaml\.load\s*\(", "yaml_load"),
|
||||||
|
(r"\brequests?\.", "http_request"),
|
||||||
|
(r"password\s*=", "password_assign"),
|
||||||
|
(r"secret\s*=", "secret_assign"),
|
||||||
|
(r"api_key\s*=", "api_key_assign"),
|
||||||
|
],
|
||||||
|
"javascript": [
|
||||||
|
(r"\beval\s*\(", "eval"),
|
||||||
|
(r"\bFunction\s*\(", "function_constructor"),
|
||||||
|
(r"innerHTML\s*=", "innerHTML"),
|
||||||
|
(r"outerHTML\s*=", "outerHTML"),
|
||||||
|
(r"document\.write\s*\(", "document_write"),
|
||||||
|
(r"\.exec\s*\(", "exec"),
|
||||||
|
(r"\.query\s*\(.*\+", "sql_concat"),
|
||||||
|
(r"password\s*[=:]", "password_assign"),
|
||||||
|
(r"apiKey\s*[=:]", "api_key_assign"),
|
||||||
|
],
|
||||||
|
"java": [
|
||||||
|
(r"Runtime\.getRuntime\(\)\.exec", "runtime_exec"),
|
||||||
|
(r"ProcessBuilder", "process_builder"),
|
||||||
|
(r"\.executeQuery\s*\(.*\+", "sql_concat"),
|
||||||
|
(r"ObjectInputStream", "deserialization"),
|
||||||
|
(r"XMLDecoder", "xml_decoder"),
|
||||||
|
(r"password\s*=", "password_assign"),
|
||||||
|
],
|
||||||
|
"go": [
|
||||||
|
(r"exec\.Command\s*\(", "exec_command"),
|
||||||
|
(r"\.Query\s*\(.*\+", "sql_concat"),
|
||||||
|
(r"\.Exec\s*\(.*\+", "sql_concat"),
|
||||||
|
(r"template\.HTML\s*\(", "unsafe_html"),
|
||||||
|
(r"password\s*=", "password_assign"),
|
||||||
|
],
|
||||||
|
"php": [
|
||||||
|
(r"\beval\s*\(", "eval"),
|
||||||
|
(r"\bexec\s*\(", "exec"),
|
||||||
|
(r"\bsystem\s*\(", "system"),
|
||||||
|
(r"\bshell_exec\s*\(", "shell_exec"),
|
||||||
|
(r"\$_GET\[", "get_input"),
|
||||||
|
(r"\$_POST\[", "post_input"),
|
||||||
|
(r"\$_REQUEST\[", "request_input"),
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
max_chunk_size: int = 1500,
|
||||||
|
min_chunk_size: int = 100,
|
||||||
|
overlap_size: int = 50,
|
||||||
|
preserve_structure: bool = True,
|
||||||
|
use_tree_sitter: bool = True,
|
||||||
|
):
|
||||||
|
self.max_chunk_size = max_chunk_size
|
||||||
|
self.min_chunk_size = min_chunk_size
|
||||||
|
self.overlap_size = overlap_size
|
||||||
|
self.preserve_structure = preserve_structure
|
||||||
|
self.use_tree_sitter = use_tree_sitter
|
||||||
|
|
||||||
|
self._ts_parser = TreeSitterParser() if use_tree_sitter else None
|
||||||
|
|
||||||
|
def detect_language(self, file_path: str) -> str:
|
||||||
|
"""检测编程语言"""
|
||||||
|
ext = Path(file_path).suffix.lower()
|
||||||
|
return TreeSitterParser.LANGUAGE_MAP.get(ext, "text")
|
||||||
|
|
||||||
|
def split_file(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
file_path: str,
|
||||||
|
language: Optional[str] = None
|
||||||
|
) -> List[CodeChunk]:
|
||||||
|
"""
|
||||||
|
分割单个文件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: 文件内容
|
||||||
|
file_path: 文件路径
|
||||||
|
language: 编程语言(可选)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
代码块列表
|
||||||
|
"""
|
||||||
|
if not content or not content.strip():
|
||||||
|
return []
|
||||||
|
|
||||||
|
if language is None:
|
||||||
|
language = self.detect_language(file_path)
|
||||||
|
|
||||||
|
chunks = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 尝试使用 Tree-sitter 解析
|
||||||
|
if self.use_tree_sitter and self._ts_parser:
|
||||||
|
tree = self._ts_parser.parse(content, language)
|
||||||
|
if tree:
|
||||||
|
chunks = self._split_by_ast(content, file_path, language, tree)
|
||||||
|
|
||||||
|
# 如果 AST 解析失败或没有结果,使用增强的正则解析
|
||||||
|
if not chunks:
|
||||||
|
chunks = self._split_by_enhanced_regex(content, file_path, language)
|
||||||
|
|
||||||
|
# 如果还是没有,使用基于行的分块
|
||||||
|
if not chunks:
|
||||||
|
chunks = self._split_by_lines(content, file_path, language)
|
||||||
|
|
||||||
|
# 后处理:提取安全指标
|
||||||
|
for chunk in chunks:
|
||||||
|
chunk.security_indicators = self._extract_security_indicators(
|
||||||
|
chunk.content, language
|
||||||
|
)
|
||||||
|
|
||||||
|
# 后处理:使用语义分析增强
|
||||||
|
self._enrich_chunks_with_semantics(chunks, content, language)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"分块失败 {file_path}: {e}, 使用简单分块")
|
||||||
|
chunks = self._split_by_lines(content, file_path, language)
|
||||||
|
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
def _split_by_ast(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
file_path: str,
|
||||||
|
language: str,
|
||||||
|
tree: Any
|
||||||
|
) -> List[CodeChunk]:
|
||||||
|
"""基于 AST 分块"""
|
||||||
|
chunks = []
|
||||||
|
lines = content.split('\n')
|
||||||
|
|
||||||
|
# 提取定义
|
||||||
|
definitions = self._ts_parser.extract_definitions(tree, content, language)
|
||||||
|
|
||||||
|
if not definitions:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 为每个定义创建代码块
|
||||||
|
for defn in definitions:
|
||||||
|
start_line = defn["start_point"][0]
|
||||||
|
end_line = defn["end_point"][0]
|
||||||
|
|
||||||
|
# 提取代码内容
|
||||||
|
chunk_lines = lines[start_line:end_line + 1]
|
||||||
|
chunk_content = '\n'.join(chunk_lines)
|
||||||
|
|
||||||
|
if len(chunk_content.strip()) < self.min_chunk_size // 4:
|
||||||
|
continue
|
||||||
|
|
||||||
|
chunk_type = ChunkType.CLASS if defn["type"] == "class" else \
|
||||||
|
ChunkType.FUNCTION if defn["type"] in ["function", "method"] else \
|
||||||
|
ChunkType.INTERFACE if defn["type"] == "interface" else \
|
||||||
|
ChunkType.STRUCT if defn["type"] == "struct" else \
|
||||||
|
ChunkType.IMPORT if defn["type"] == "import" else \
|
||||||
|
ChunkType.MODULE
|
||||||
|
|
||||||
|
chunk = CodeChunk(
|
||||||
|
id="",
|
||||||
|
content=chunk_content,
|
||||||
|
file_path=file_path,
|
||||||
|
language=language,
|
||||||
|
chunk_type=chunk_type,
|
||||||
|
line_start=start_line + 1,
|
||||||
|
line_end=end_line + 1,
|
||||||
|
byte_start=defn["start_byte"],
|
||||||
|
byte_end=defn["end_byte"],
|
||||||
|
name=defn.get("name"),
|
||||||
|
parent_name=defn.get("parent_name"),
|
||||||
|
ast_type=defn.get("node_type"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 如果块太大,进一步分割
|
||||||
|
if chunk.estimated_tokens > self.max_chunk_size:
|
||||||
|
sub_chunks = self._split_large_chunk(chunk)
|
||||||
|
chunks.extend(sub_chunks)
|
||||||
|
else:
|
||||||
|
chunks.append(chunk)
|
||||||
|
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
def _split_by_enhanced_regex(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
file_path: str,
|
||||||
|
language: str
|
||||||
|
) -> List[CodeChunk]:
|
||||||
|
"""增强的正则表达式分块(支持更多语言)"""
|
||||||
|
chunks = []
|
||||||
|
lines = content.split('\n')
|
||||||
|
|
||||||
|
# 各语言的定义模式
|
||||||
|
patterns = {
|
||||||
|
"python": [
|
||||||
|
(r"^(\s*)class\s+(\w+)(?:\s*\([^)]*\))?\s*:", ChunkType.CLASS),
|
||||||
|
(r"^(\s*)(?:async\s+)?def\s+(\w+)\s*\([^)]*\)\s*(?:->[^:]+)?:", ChunkType.FUNCTION),
|
||||||
|
],
|
||||||
|
"javascript": [
|
||||||
|
(r"^(\s*)(?:export\s+)?class\s+(\w+)", ChunkType.CLASS),
|
||||||
|
(r"^(\s*)(?:export\s+)?(?:async\s+)?function\s*(\w*)\s*\(", ChunkType.FUNCTION),
|
||||||
|
(r"^(\s*)(?:export\s+)?(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s+)?\([^)]*\)\s*=>", ChunkType.FUNCTION),
|
||||||
|
],
|
||||||
|
"typescript": [
|
||||||
|
(r"^(\s*)(?:export\s+)?(?:abstract\s+)?class\s+(\w+)", ChunkType.CLASS),
|
||||||
|
(r"^(\s*)(?:export\s+)?interface\s+(\w+)", ChunkType.INTERFACE),
|
||||||
|
(r"^(\s*)(?:export\s+)?(?:async\s+)?function\s*(\w*)", ChunkType.FUNCTION),
|
||||||
|
],
|
||||||
|
"java": [
|
||||||
|
(r"^(\s*)(?:public|private|protected)?\s*(?:static\s+)?(?:final\s+)?class\s+(\w+)", ChunkType.CLASS),
|
||||||
|
(r"^(\s*)(?:public|private|protected)?\s*interface\s+(\w+)", ChunkType.INTERFACE),
|
||||||
|
(r"^(\s*)(?:public|private|protected)?\s*(?:static\s+)?[\w<>\[\],\s]+\s+(\w+)\s*\([^)]*\)\s*(?:throws\s+[\w,\s]+)?\s*\{", ChunkType.METHOD),
|
||||||
|
],
|
||||||
|
"go": [
|
||||||
|
(r"^type\s+(\w+)\s+struct\s*\{", ChunkType.STRUCT),
|
||||||
|
(r"^type\s+(\w+)\s+interface\s*\{", ChunkType.INTERFACE),
|
||||||
|
(r"^func\s+(?:\([^)]+\)\s+)?(\w+)\s*\([^)]*\)", ChunkType.FUNCTION),
|
||||||
|
],
|
||||||
|
"php": [
|
||||||
|
(r"^(\s*)(?:abstract\s+)?class\s+(\w+)", ChunkType.CLASS),
|
||||||
|
(r"^(\s*)interface\s+(\w+)", ChunkType.INTERFACE),
|
||||||
|
(r"^(\s*)(?:public|private|protected)?\s*(?:static\s+)?function\s+(\w+)", ChunkType.FUNCTION),
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
lang_patterns = patterns.get(language, [])
|
||||||
|
if not lang_patterns:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 找到所有定义的位置
|
||||||
|
definitions = []
|
||||||
|
for i, line in enumerate(lines):
|
||||||
|
for pattern, chunk_type in lang_patterns:
|
||||||
|
match = re.match(pattern, line)
|
||||||
|
if match:
|
||||||
|
indent = len(match.group(1)) if match.lastindex >= 1 else 0
|
||||||
|
name = match.group(2) if match.lastindex >= 2 else None
|
||||||
|
definitions.append({
|
||||||
|
"line": i,
|
||||||
|
"indent": indent,
|
||||||
|
"name": name,
|
||||||
|
"type": chunk_type,
|
||||||
|
})
|
||||||
|
break
|
||||||
|
|
||||||
|
if not definitions:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 计算每个定义的范围
|
||||||
|
for i, defn in enumerate(definitions):
|
||||||
|
start_line = defn["line"]
|
||||||
|
base_indent = defn["indent"]
|
||||||
|
|
||||||
|
# 查找结束位置
|
||||||
|
end_line = len(lines) - 1
|
||||||
|
for j in range(start_line + 1, len(lines)):
|
||||||
|
line = lines[j]
|
||||||
|
if line.strip():
|
||||||
|
current_indent = len(line) - len(line.lstrip())
|
||||||
|
# 如果缩进回到基础级别,检查是否是下一个定义
|
||||||
|
if current_indent <= base_indent:
|
||||||
|
# 检查是否是下一个定义
|
||||||
|
is_next_def = any(d["line"] == j for d in definitions)
|
||||||
|
if is_next_def or (current_indent < base_indent):
|
||||||
|
end_line = j - 1
|
||||||
|
break
|
||||||
|
|
||||||
|
chunk_content = '\n'.join(lines[start_line:end_line + 1])
|
||||||
|
|
||||||
|
if len(chunk_content.strip()) < 10:
|
||||||
|
continue
|
||||||
|
|
||||||
|
chunk = CodeChunk(
|
||||||
|
id="",
|
||||||
|
content=chunk_content,
|
||||||
|
file_path=file_path,
|
||||||
|
language=language,
|
||||||
|
chunk_type=defn["type"],
|
||||||
|
line_start=start_line + 1,
|
||||||
|
line_end=end_line + 1,
|
||||||
|
name=defn.get("name"),
|
||||||
|
)
|
||||||
|
|
||||||
|
if chunk.estimated_tokens > self.max_chunk_size:
|
||||||
|
sub_chunks = self._split_large_chunk(chunk)
|
                chunks.extend(sub_chunks)
            else:
                chunks.append(chunk)

        return chunks

    def _split_by_lines(
        self,
        content: str,
        file_path: str,
        language: str
    ) -> List[CodeChunk]:
        """Split by line count (fallback strategy)."""
        chunks = []
        lines = content.split('\n')

        # Estimate tokens per line (roughly 4 characters per token)
        total_tokens = len(content) // 4
        avg_tokens_per_line = max(1, total_tokens // max(1, len(lines)))
        lines_per_chunk = max(10, self.max_chunk_size // avg_tokens_per_line)
        overlap_lines = self.overlap_size // avg_tokens_per_line
        # Keep the step positive even if the overlap exceeds the chunk size.
        step = max(1, lines_per_chunk - overlap_lines)

        for i in range(0, len(lines), step):
            end = min(i + lines_per_chunk, len(lines))
            chunk_content = '\n'.join(lines[i:end])

            if len(chunk_content.strip()) < 10:
                continue

            chunk = CodeChunk(
                id="",
                content=chunk_content,
                file_path=file_path,
                language=language,
                chunk_type=ChunkType.MODULE,
                line_start=i + 1,
                line_end=end,
            )
            chunks.append(chunk)

            if end >= len(lines):
                break

        return chunks

    def _split_large_chunk(self, chunk: CodeChunk) -> List[CodeChunk]:
        """Split an oversized chunk into smaller sub-chunks."""
        sub_chunks = []
        lines = chunk.content.split('\n')

        avg_tokens_per_line = max(1, chunk.estimated_tokens // max(1, len(lines)))
        lines_per_chunk = max(10, self.max_chunk_size // avg_tokens_per_line)

        for i in range(0, len(lines), lines_per_chunk):
            end = min(i + lines_per_chunk, len(lines))
            sub_content = '\n'.join(lines[i:end])

            if len(sub_content.strip()) < 10:
                continue

            sub_chunk = CodeChunk(
                id="",
                content=sub_content,
                file_path=chunk.file_path,
                language=chunk.language,
                chunk_type=chunk.chunk_type,
                line_start=chunk.line_start + i,
                line_end=chunk.line_start + end - 1,
                name=chunk.name,
                parent_name=chunk.parent_name,
            )
            sub_chunks.append(sub_chunk)

        return sub_chunks if sub_chunks else [chunk]

    def _extract_security_indicators(self, content: str, language: str) -> List[str]:
        """Extract security-related indicators from a chunk."""
        indicators = []
        patterns = self.SECURITY_PATTERNS.get(language, [])

        # Generic patterns applied to every language
        common_patterns = [
            (r"password", "password"),
            (r"secret", "secret"),
            (r"api[_-]?key", "api_key"),
            (r"token", "token"),
            (r"private[_-]?key", "private_key"),
            (r"credential", "credential"),
        ]

        all_patterns = patterns + common_patterns

        for pattern, name in all_patterns:
            try:
                if re.search(pattern, content, re.IGNORECASE):
                    if name not in indicators:
                        indicators.append(name)
            except re.error:
                continue

        return indicators[:15]

    def _enrich_chunks_with_semantics(
        self,
        chunks: List[CodeChunk],
        full_content: str,
        language: str
    ):
        """Enrich chunks with lightweight semantic analysis."""
        # Extract imports from the whole file
        imports = self._extract_imports(full_content, language)

        for chunk in chunks:
            # Attach the imports relevant to this chunk
            chunk.imports = self._filter_relevant_imports(imports, chunk.content)

            # Extract function calls
            chunk.calls = self._extract_function_calls(chunk.content, language)

            # Extract definitions
            chunk.definitions = self._extract_definitions(chunk.content, language)

    def _extract_imports(self, content: str, language: str) -> List[str]:
        """Extract import statements."""
        imports = []

        patterns = {
            "python": [
                r"^import\s+([\w.]+)",
                r"^from\s+([\w.]+)\s+import",
            ],
            "javascript": [
                r"^import\s+.*\s+from\s+['\"]([^'\"]+)['\"]",
                r"require\s*\(['\"]([^'\"]+)['\"]\)",
            ],
            "typescript": [
                r"^import\s+.*\s+from\s+['\"]([^'\"]+)['\"]",
            ],
            "java": [
                r"^import\s+([\w.]+);",
            ],
            "go": [
                r"['\"]([^'\"]+)['\"]",
            ],
        }

        for pattern in patterns.get(language, []):
            matches = re.findall(pattern, content, re.MULTILINE)
            imports.extend(matches)

        return list(set(imports))

    def _filter_relevant_imports(self, all_imports: List[str], chunk_content: str) -> List[str]:
        """Filter the imports that are actually referenced in the chunk."""
        relevant = []
        for imp in all_imports:
            module_name = imp.split('.')[-1]
            if re.search(rf'\b{re.escape(module_name)}\b', chunk_content):
                relevant.append(imp)
        return relevant[:20]

    def _extract_function_calls(self, content: str, language: str) -> List[str]:
        """Extract function calls, excluding language keywords."""
        pattern = r'\b(\w+)\s*\('
        matches = re.findall(pattern, content)

        keywords = {
            "python": {"if", "for", "while", "with", "def", "class", "return", "except", "print", "assert", "lambda"},
            "javascript": {"if", "for", "while", "switch", "function", "return", "catch", "console", "async", "await"},
            "java": {"if", "for", "while", "switch", "return", "catch", "throw", "new"},
            "go": {"if", "for", "switch", "return", "func", "go", "defer"},
        }

        lang_keywords = keywords.get(language, set())
        calls = [m for m in matches if m not in lang_keywords]

        return list(set(calls))[:30]

    def _extract_definitions(self, content: str, language: str) -> List[str]:
        """Extract identifiers defined in the chunk."""
        definitions = []

        patterns = {
            "python": [
                r"def\s+(\w+)\s*\(",
                r"class\s+(\w+)",
                r"(\w+)\s*=\s*",
            ],
            "javascript": [
                r"function\s+(\w+)",
                r"(?:const|let|var)\s+(\w+)",
                r"class\s+(\w+)",
            ],
        }

        for pattern in patterns.get(language, []):
            matches = re.findall(pattern, content)
            definitions.extend(matches)

        return list(set(definitions))[:20]

@ -17,3 +17,51 @@ reportlab>=4.0.0
weasyprint>=66.0
jinja2>=3.1.6
json-repair>=0.30.0

# ============ Agent module dependencies ============

# LangChain core
langchain>=0.1.0
langchain-community>=0.0.20
langchain-openai>=0.0.5

# LangGraph (state-graph workflows)
langgraph>=0.0.40

# Vector database
chromadb>=0.4.22

# Token counting
tiktoken>=0.5.2

# Docker sandbox
docker>=7.0.0

# Async file I/O
aiofiles>=23.2.1

# SSE streaming
sse-starlette>=1.8.2

# ============ Code parsing (advanced libraries) ============

# Tree-sitter AST parsing
tree-sitter>=0.21.0
tree-sitter-languages>=1.10.0

# Generic code parsing
pygments>=2.17.0

# ============ External security tools (optional) ============
# These can be installed via pip or a system package manager.

# Python security scanning
bandit>=1.7.0
safety>=2.3.0

# Static analysis (install the semgrep CLI separately)
# pip install semgrep

# Dependency vulnerability scanning
pip-audit>=2.6.0

@ -0,0 +1,62 @@
# DeepAudit Agent Sandbox
# Secure sandbox environment for vulnerability verification and PoC execution

FROM python:3.11-slim-bookworm

LABEL maintainer="XCodeReviewer Team"
LABEL description="Secure sandbox environment for vulnerability verification"

# Install base tooling
RUN apt-get update && apt-get install -y --no-install-recommends \
    curl \
    wget \
    netcat-openbsd \
    dnsutils \
    iputils-ping \
    ca-certificates \
    git \
    && rm -rf /var/lib/apt/lists/*

# Install Node.js (for JavaScript/TypeScript code execution)
RUN curl -fsSL https://deb.nodesource.com/setup_20.x | bash - \
    && apt-get install -y nodejs \
    && rm -rf /var/lib/apt/lists/*

# Install common Python libraries for security testing
RUN pip install --no-cache-dir \
    requests \
    httpx \
    aiohttp \
    beautifulsoup4 \
    lxml \
    pycryptodome \
    paramiko \
    pyjwt \
    python-jose \
    sqlparse

# Create a non-root user
RUN groupadd -g 1000 sandbox && \
    useradd -u 1000 -g sandbox -m -s /bin/bash sandbox

# Create working directories
RUN mkdir -p /workspace /tmp/sandbox && \
    chown -R sandbox:sandbox /workspace /tmp/sandbox

# Environment variables
ENV HOME=/home/sandbox
ENV PATH=/home/sandbox/.local/bin:$PATH
ENV PYTHONDONTWRITEBYTECODE=1
ENV PYTHONUNBUFFERED=1

# Restrict the Python import path
ENV PYTHONPATH=/workspace

# Switch to the non-root user
USER sandbox

WORKDIR /workspace

# Default command
CMD ["/bin/bash"]

@ -0,0 +1,25 @@
#!/bin/bash
# Build the sandbox image

set -e

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
IMAGE_NAME="deepaudit-sandbox"
IMAGE_TAG="latest"

echo "Building sandbox image: ${IMAGE_NAME}:${IMAGE_TAG}"

docker build \
    -t "${IMAGE_NAME}:${IMAGE_TAG}" \
    -f "${SCRIPT_DIR}/Dockerfile" \
    "${SCRIPT_DIR}"

echo "Build complete: ${IMAGE_NAME}:${IMAGE_TAG}"

# Verify the image
echo "Verifying image..."
docker run --rm "${IMAGE_NAME}:${IMAGE_TAG}" python3 --version
docker run --rm "${IMAGE_NAME}:${IMAGE_TAG}" node --version

echo "Sandbox image ready!"

@ -0,0 +1,266 @@
{
  "defaultAction": "SCMP_ACT_ERRNO",
  "defaultErrnoRet": 1,
  "archMap": [
    {
      "architecture": "SCMP_ARCH_X86_64",
      "subArchitectures": ["SCMP_ARCH_X86", "SCMP_ARCH_X32"]
    },
    {
      "architecture": "SCMP_ARCH_AARCH64",
      "subArchitectures": ["SCMP_ARCH_ARM"]
    }
  ],
  "syscalls": [
    {
      "names": [
        "accept", "accept4", "access", "arch_prctl", "bind", "brk", "capget", "capset",
        "chdir", "chmod", "chown", "clock_getres", "clock_gettime", "clock_nanosleep",
        "clone", "clone3", "close", "connect", "copy_file_range", "dup", "dup2", "dup3",
        "epoll_create", "epoll_create1", "epoll_ctl", "epoll_pwait", "epoll_pwait2",
        "epoll_wait", "eventfd", "eventfd2", "execve", "execveat", "exit", "exit_group",
        "faccessat", "faccessat2", "fadvise64", "fallocate", "fchdir", "fchmod",
        "fchmodat", "fchown", "fchownat", "fcntl", "fdatasync", "fgetxattr",
        "flistxattr", "flock", "fork", "fstat", "fstatfs", "fsync", "ftruncate",
        "futex", "getcwd", "getdents", "getdents64", "getegid", "geteuid", "getgid",
        "getgroups", "getpeername", "getpgid", "getpgrp", "getpid", "getppid",
        "getpriority", "getrandom", "getresgid", "getresuid", "getrlimit", "getrusage",
        "getsid", "getsockname", "getsockopt", "gettid", "gettimeofday", "getuid",
        "getxattr", "inotify_add_watch", "inotify_init", "inotify_init1",
        "inotify_rm_watch", "ioctl", "kill", "lchown", "lgetxattr", "link", "linkat",
        "listen", "listxattr", "llistxattr", "lseek", "lstat", "madvise", "membarrier",
        "memfd_create", "mincore", "mkdir", "mkdirat", "mknod", "mknodat", "mlock",
        "mlock2", "mlockall", "mmap", "mprotect", "mremap", "msync", "munlock",
        "munlockall", "munmap", "name_to_handle_at", "nanosleep", "newfstatat", "open",
        "openat", "openat2", "pause", "pipe", "pipe2", "poll", "ppoll", "prctl",
        "pread64", "preadv", "preadv2", "prlimit64", "pselect6", "pwrite64", "pwritev",
        "pwritev2", "read", "readahead", "readlink", "readlinkat", "readv", "recv",
        "recvfrom", "recvmmsg", "recvmsg", "rename", "renameat", "renameat2",
        "restart_syscall", "rmdir", "rseq", "rt_sigaction", "rt_sigpending",
        "rt_sigprocmask", "rt_sigqueueinfo", "rt_sigreturn", "rt_sigsuspend",
        "rt_sigtimedwait", "rt_tgsigqueueinfo", "sched_getaffinity", "sched_getattr",
        "sched_getparam", "sched_get_priority_max", "sched_get_priority_min",
        "sched_getscheduler", "sched_rr_get_interval", "sched_setaffinity",
        "sched_setattr", "sched_setparam", "sched_setscheduler", "sched_yield",
        "seccomp", "select", "semctl", "semget", "semop", "semtimedop", "send",
        "sendfile", "sendmmsg", "sendmsg", "sendto", "set_robust_list",
        "set_tid_address", "setgid", "setgroups", "setitimer", "setpgid",
        "setpriority", "setregid", "setresgid", "setresuid", "setreuid", "setsid",
        "setsockopt", "setuid", "shmat", "shmctl", "shmdt", "shmget", "shutdown",
        "sigaltstack", "signalfd", "signalfd4", "socket", "socketpair", "splice",
        "stat", "statfs", "statx", "symlink", "symlinkat", "sync", "sync_file_range",
        "syncfs", "sysinfo", "tee", "tgkill", "time", "timer_create", "timer_delete",
        "timer_getoverrun", "timer_gettime", "timer_settime", "timerfd_create",
        "timerfd_gettime", "timerfd_settime", "times", "tkill", "truncate", "umask",
        "uname", "unlink", "unlinkat", "utime", "utimensat", "utimes", "vfork",
        "wait4", "waitid", "waitpid", "write", "writev"
      ],
      "action": "SCMP_ACT_ALLOW"
    }
  ]
}

@ -0,0 +1,298 @@
# DeepAudit Agent Audit Module

## Overview

The Agent audit module is DeepAudit's advanced security auditing capability: a hybrid AI Agent architecture built on a **LangGraph** state graph that performs autonomous code security analysis and vulnerability verification.

## LangGraph Workflow Architecture

```
START
  │
  ▼
┌─ Recon Node (reconnaissance) ───────────────────────────────
│  • Project structure analysis    • Tech-stack detection
│  • Entry-point discovery         • Dependency scanning
│  Tools: list_files, npm_audit, safety_scan, gitleaks_scan
└─────────────────────────────┬───────────────────────────────
                              ▼
┌─ Analysis Node (vulnerability analysis) ────────────────────
│  • Semgrep static analysis       • RAG semantic search
│  • Pattern matching              • LLM deep analysis
│  • Data-flow tracing
│  Tools: semgrep_scan, bandit_scan, rag_query,
│         code_analysis, pattern_match
└─────────────────────────────┬───────────────────────────────
                              ▼
┌─ Verification Node (vulnerability verification) ────────────
│  • LLM vulnerability validation  • Sandbox testing
│  • PoC generation                • False-positive filtering
│  Tools: vulnerability_validation, sandbox_exec,
│         verify_vulnerability
└─────────────────────────────┬───────────────────────────────
                              ▼
┌─ Report Node (report generation) ───────────────────────────
│  • Finding aggregation           • Security scoring
│  • Remediation advice            • Statistics
└─────────────────────────────┬───────────────────────────────
                              ▼
                             END

State transitions:
  • Recon → Analysis: enter analysis once entry points have been collected
  • Analysis → Analysis: keep iterating while a large number of issues are being found
  • Analysis → Verification: enter verification once there are findings
  • Verification → Analysis: fall back to analysis when the false-positive rate is high
  • Verification → Report: generate the report once verification completes
```

## Core Features

### 1. LangGraph State Graph

- **Declarative workflow**: agent collaboration is defined as a graph
- **Automatic state merging**: `Annotated[List, operator.add]` accumulates findings
- **Conditional routing**: the next step is chosen dynamically from the current state
- **Checkpoint recovery**: interrupted tasks can be resumed

A minimal state-definition sketch is shown below.

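The following is a hypothetical sketch of how such a graph could be declared; the state fields, node bodies and the iteration threshold are illustrative assumptions, not the module's actual `audit_graph.py` definitions.

```python
# Hypothetical sketch only: state fields, node bodies and the iteration
# threshold are assumptions, not the actual audit_graph.py implementation.
import operator
from typing import Annotated, List, TypedDict

from langgraph.graph import END, StateGraph


class AuditState(TypedDict):
    # operator.add tells LangGraph to concatenate the lists returned by nodes,
    # so findings accumulate across iterations instead of being overwritten.
    findings: Annotated[List[dict], operator.add]
    entry_points: Annotated[List[str], operator.add]
    iteration: int


def recon_node(state: AuditState) -> dict:
    return {"entry_points": ["app/main.py"], "iteration": state["iteration"] + 1}


def analysis_node(state: AuditState) -> dict:
    finding = {"type": "sql_injection", "file_path": "app/api/users.py"}
    return {"findings": [finding], "iteration": state["iteration"] + 1}


def route_after_analysis(state: AuditState) -> str:
    # Conditional routing: keep analysing, verify the findings, or finish up.
    if state["iteration"] >= 3:
        return "report"
    return "verification" if state["findings"] else "analysis"


graph = StateGraph(AuditState)
graph.add_node("recon", recon_node)
graph.add_node("analysis", analysis_node)
graph.add_node("verification", lambda state: {})
graph.add_node("report", lambda state: {})
graph.set_entry_point("recon")
graph.add_edge("recon", "analysis")
graph.add_conditional_edges(
    "analysis",
    route_after_analysis,
    {"analysis": "analysis", "verification": "verification", "report": "report"},
)
graph.add_edge("verification", "report")
graph.add_edge("report", END)
app = graph.compile()

# app.invoke({"findings": [], "entry_points": [], "iteration": 0})
```

The key point is the `operator.add` annotation: each node returns only its new findings, and LangGraph appends them to the accumulated list instead of replacing it.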
### 2. Agent Toolset

#### Built-in Tools

| Tool | Purpose | Node |
|------|---------|------|
| `list_files` | Directory listing | Recon |
| `read_file` | File reading | All |
| `search_code` | Code search | Analysis |
| `rag_query` | Semantic retrieval | Analysis |
| `security_search` | Security-focused code search | Analysis |
| `function_context` | Function context | Analysis |
| `pattern_match` | Pattern matching | Analysis |
| `code_analysis` | LLM analysis | Analysis |
| `dataflow_analysis` | Data-flow tracing | Analysis |
| `vulnerability_validation` | Vulnerability validation | Verification |
| `sandbox_exec` | Sandbox execution | Verification |
| `verify_vulnerability` | Automated verification | Verification |

#### External Security Tools

| Tool | Purpose | Typical use |
|------|---------|-------------|
| `semgrep_scan` | Semgrep static analysis | Fast multi-language scanning |
| `bandit_scan` | Bandit Python scanning | Python security analysis |
| `gitleaks_scan` | Gitleaks secret detection | Leaked-credential detection |
| `trufflehog_scan` | TruffleHog scanning | Deep secret scanning |
| `npm_audit` | npm dependency audit | Node.js dependency vulnerabilities |
| `safety_scan` | Safety Python audit | Python dependency vulnerabilities |
| `osv_scan` | OSV vulnerability scanning | Multi-language dependency vulnerabilities |

A sketch of exposing one of these tools to the agents follows.

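As an illustration of how a table entry becomes something an agent can call, here is a minimal, hypothetical sketch using LangChain's `@tool` decorator; the body and return format are assumptions, only the tool name `rag_query` is taken from the table above.

```python
# Hypothetical sketch: the body and return format are assumptions;
# only the tool name "rag_query" comes from the table above.
from langchain_core.tools import tool


@tool
def rag_query(query: str, top_k: int = 5) -> str:
    """Semantically search the indexed code base and return the best-matching chunks."""
    # The real tool would query the persistent RAG index; this stand-in only
    # documents the expected shape of the result handed back to the LLM.
    results = [
        {"file_path": "app/api/users.py", "score": 0.91,
         "snippet": 'db.execute(f"SELECT * FROM users WHERE id = {user_id}")'},
    ][:top_k]
    return "\n".join(
        f"{r['file_path']} (score {r['score']}): {r['snippet']}" for r in results
    )
```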
### 3. RAG System

- **Code chunking**: Tree-sitter-AST-based intelligent chunking
- **Vector store**: persisted in ChromaDB
- **Multi-language support**: Python, JavaScript, TypeScript, Java, Go, PHP, Rust and more
- **Embedding model**: configured independently; supports OpenAI, Ollama, Cohere and HuggingFace

A minimal indexing/query sketch follows.

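The sketch below shows the ChromaDB side of this pipeline, assuming chunk metadata similar to what the chunker in this commit produces; the collection name and the placeholder vectors are illustrative assumptions.

```python
# Hypothetical sketch: collection name, metadata fields and the placeholder
# vectors are assumptions used only to illustrate the persist/query flow.
import chromadb

client = chromadb.PersistentClient(path="./data/vectordb")
collection = client.get_or_create_collection("code_chunks")

# Index: one document per code chunk; real vectors come from the configured
# embedding provider (OpenAI, Ollama, Cohere or HuggingFace).
collection.add(
    ids=["chunk-0001"],
    documents=["def login(username, password): ..."],
    embeddings=[[0.01] * 1536],
    metadatas=[{"file_path": "app/auth.py", "language": "python", "chunk_type": "function"}],
)

# Query: embed the question with the same model, then search by similarity.
hits = collection.query(query_embeddings=[[0.01] * 1536], n_results=5)
print(hits["ids"][0], hits["metadatas"][0])
```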
### 4. Security Sandbox

- **Docker isolation**: execution inside a locked-down container
- **Resource limits**: memory and CPU caps
- **Network isolation**: network access is configurable
- **seccomp policy**: system-call whitelist

A sketch of how such a container might be launched is shown below.

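A minimal sketch using the `docker` Python SDK from the requirements and the image/seccomp profile built above; the exact options the sandbox tool passes are assumptions chosen to mirror the documented defaults.

```python
# Hypothetical sketch: option values mirror the documented defaults
# (memory/CPU limits, disabled network, seccomp whitelist) but are assumptions.
import docker

client = docker.from_env()

with open("docker/sandbox/seccomp.json", "r", encoding="utf-8") as fh:
    seccomp_profile = fh.read()

container = client.containers.run(
    image="deepaudit-sandbox:latest",          # SANDBOX_IMAGE
    command=["python3", "/workspace/poc.py"],
    detach=True,
    network_disabled=True,                     # SANDBOX_NETWORK_DISABLED=true
    mem_limit="512m",                          # SANDBOX_MEMORY_LIMIT
    nano_cpus=1_000_000_000,                   # SANDBOX_CPU_LIMIT=1.0 CPU
    security_opt=[f"seccomp={seccomp_profile}"],
    volumes={"/tmp/poc": {"bind": "/workspace", "mode": "ro"}},
    user="sandbox",
)
try:
    result = container.wait(timeout=60)
    print(result["StatusCode"], container.logs().decode())
finally:
    container.remove(force=True)
```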
## Configuration

### Environment Variables

```bash
# LLM configuration
DEFAULT_LLM_MODEL=gpt-4-turbo-preview
LLM_API_KEY=your-api-key
LLM_BASE_URL=https://api.openai.com/v1

# Embedding model configuration (independent of the LLM)
EMBEDDING_PROVIDER=openai
EMBEDDING_MODEL=text-embedding-3-small

# Vector database
VECTOR_DB_PATH=./data/vectordb

# Sandbox configuration
SANDBOX_IMAGE=deepaudit-sandbox:latest
SANDBOX_MEMORY_LIMIT=512m
SANDBOX_CPU_LIMIT=1.0
SANDBOX_NETWORK_DISABLED=true
```

### Agent Task Configuration

```json
{
  "target_vulnerabilities": [
    "sql_injection",
    "xss",
    "command_injection",
    "path_traversal",
    "ssrf"
  ],
  "verification_level": "sandbox",
  "exclude_patterns": ["node_modules", "__pycache__", ".git"],
  "max_iterations": 3,
  "timeout_seconds": 1800
}
```

## Deployment

### 1. Install dependencies

```bash
cd backend
pip install -r requirements.txt

# Optional: install external tools
pip install semgrep bandit safety
brew install gitleaks trufflehog osv-scanner  # macOS
```

### 2. Build the sandbox image

```bash
cd docker/sandbox
./build.sh
```

### 3. Run database migrations

```bash
alembic upgrade head
```

### 4. Start the service

```bash
uvicorn app.main:app --host 0.0.0.0 --port 8000
```

## API Endpoints

### Create a task

```http
POST /api/v1/agent-tasks/
Content-Type: application/json

{
  "project_id": "xxx",
  "name": "安全审计",
  "target_vulnerabilities": ["sql_injection", "xss"],
  "verification_level": "sandbox",
  "max_iterations": 3
}
```

### Event stream

```http
GET /api/v1/agent-tasks/{task_id}/events
Accept: text/event-stream
```

### Get findings

```http
GET /api/v1/agent-tasks/{task_id}/findings?verified_only=true
```

### Task summary

```http
GET /api/v1/agent-tasks/{task_id}/summary
```

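Since the events endpoint is served as Server-Sent Events (`sse-starlette` is in the requirements), it can be consumed as a plain SSE stream. The sketch below uses `httpx`; the host, task id and payload fields are assumptions.

```python
# Hypothetical sketch of consuming the event stream with httpx;
# the host, task id and payload fields are assumptions.
import json

import httpx

task_id = "your-task-id"
url = f"http://localhost:8000/api/v1/agent-tasks/{task_id}/events"

with httpx.stream("GET", url, headers={"Accept": "text/event-stream"}, timeout=None) as response:
    for line in response.iter_lines():
        # SSE data lines are prefixed with "data: "; skip comments and keep-alives.
        if line.startswith("data: "):
            event = json.loads(line[len("data: "):])
            print(event.get("event_type"), event.get("message"))
```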
## Supported Vulnerability Types

| Type | Description |
|------|-------------|
| `sql_injection` | SQL injection |
| `xss` | Cross-site scripting |
| `command_injection` | Command injection |
| `path_traversal` | Path traversal |
| `ssrf` | Server-side request forgery |
| `xxe` | XML external entity |
| `insecure_deserialization` | Insecure deserialization |
| `hardcoded_secret` | Hard-coded secrets |
| `weak_crypto` | Weak cryptography |
| `authentication_bypass` | Authentication bypass |
| `authorization_bypass` | Authorization bypass |
| `idor` | Insecure direct object reference |

## Directory Layout

```
backend/app/services/agent/
├── __init__.py              # Module exports
├── event_manager.py         # Event management
├── agents/                  # Agent implementations
│   ├── __init__.py
│   ├── base.py              # Agent base class
│   ├── recon.py             # Reconnaissance agent
│   ├── analysis.py          # Vulnerability analysis agent
│   ├── verification.py      # Vulnerability verification agent
│   └── orchestrator.py      # Orchestration agent
├── graph/                   # LangGraph workflow
│   ├── __init__.py
│   ├── audit_graph.py       # State definition and graph construction
│   ├── nodes.py             # Node implementations
│   └── runner.py            # Executor
├── tools/                   # Agent tools
│   ├── __init__.py
│   ├── base.py              # Tool base class
│   ├── rag_tool.py          # RAG tool
│   ├── pattern_tool.py      # Pattern-matching tool
│   ├── code_analysis_tool.py
│   ├── file_tool.py         # File operations
│   ├── sandbox_tool.py      # Sandbox tool
│   └── external_tools.py    # External security tools
└── prompts/                 # System prompts
    ├── __init__.py
    └── system_prompts.py
```

## Troubleshooting

### Check the sandbox image

```bash
docker images | grep deepaudit-sandbox
```

### View logs

```bash
tail -f logs/agent.log
```

### Common issues

1. **RAG initialization fails**: check the embedding model configuration and API key
2. **Sandbox fails to start**: make sure Docker is running
3. **External tools unavailable**: check that semgrep/bandit etc. are installed

@ -54,3 +54,4 @@ EXPOSE 3000

ENTRYPOINT ["/docker-entrypoint.sh"]
CMD ["serve", "-s", "dist", "-l", "3000"]

@ -17,3 +17,4 @@ export const ProtectedRoute = () => {
};

|
@ -5,6 +5,7 @@ import RecycleBin from "@/pages/RecycleBin";
|
||||||
import InstantAnalysis from "@/pages/InstantAnalysis";
|
import InstantAnalysis from "@/pages/InstantAnalysis";
|
||||||
import AuditTasks from "@/pages/AuditTasks";
|
import AuditTasks from "@/pages/AuditTasks";
|
||||||
import TaskDetail from "@/pages/TaskDetail";
|
import TaskDetail from "@/pages/TaskDetail";
|
||||||
|
import AgentAudit from "@/pages/AgentAudit";
|
||||||
import AdminDashboard from "@/pages/AdminDashboard";
|
import AdminDashboard from "@/pages/AdminDashboard";
|
||||||
import LogsPage from "@/pages/LogsPage";
|
import LogsPage from "@/pages/LogsPage";
|
||||||
import Account from "@/pages/Account";
|
import Account from "@/pages/Account";
|
||||||
|
|
@ -56,6 +57,12 @@ const routes: RouteConfig[] = [
|
||||||
element: <TaskDetail />,
|
element: <TaskDetail />,
|
||||||
visible: false,
|
visible: false,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "Agent审计",
|
||||||
|
path: "/agent-audit/:taskId",
|
||||||
|
element: <AgentAudit />,
|
||||||
|
visible: false,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "审计规则",
|
name: "审计规则",
|
||||||
path: "/audit-rules",
|
path: "/audit-rules",
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,158 @@
|
||||||
|
/**
|
||||||
|
* Agent 审计模式选择器
|
||||||
|
* 允许用户在快速审计和 Agent 审计模式之间选择
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { Bot, Zap, CheckCircle2, Clock, Shield, Code } from "lucide-react";
|
||||||
|
import { cn } from "@/shared/utils/utils";
|
||||||
|
|
||||||
|
export type AuditMode = "fast" | "agent";
|
||||||
|
|
||||||
|
interface AgentModeSelectorProps {
|
||||||
|
value: AuditMode;
|
||||||
|
onChange: (mode: AuditMode) => void;
|
||||||
|
disabled?: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
export default function AgentModeSelector({
|
||||||
|
value,
|
||||||
|
onChange,
|
||||||
|
disabled = false,
|
||||||
|
}: AgentModeSelectorProps) {
|
||||||
|
return (
|
||||||
|
<div className="space-y-3">
|
||||||
|
<div className="flex items-center gap-2 mb-2">
|
||||||
|
<Shield className="w-4 h-4 text-indigo-700" />
|
||||||
|
<span className="font-mono text-sm font-bold text-indigo-900 uppercase">
|
||||||
|
审计模式
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="grid grid-cols-2 gap-3">
|
||||||
|
{/* 快速审计模式 */}
|
||||||
|
<label
|
||||||
|
className={cn(
|
||||||
|
"relative flex flex-col p-4 border-2 cursor-pointer transition-all rounded-none",
|
||||||
|
value === "fast"
|
||||||
|
? "border-amber-500 bg-amber-50"
|
||||||
|
: "border-gray-300 hover:border-gray-400 bg-white",
|
||||||
|
disabled && "opacity-50 cursor-not-allowed"
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
<input
|
||||||
|
type="radio"
|
||||||
|
name="auditMode"
|
||||||
|
value="fast"
|
||||||
|
checked={value === "fast"}
|
||||||
|
onChange={() => onChange("fast")}
|
||||||
|
disabled={disabled}
|
||||||
|
className="sr-only"
|
||||||
|
/>
|
||||||
|
|
||||||
|
<div className="flex items-center gap-2 mb-2">
|
||||||
|
<div className="p-1.5 bg-amber-100 border border-amber-300">
|
||||||
|
<Zap className="w-4 h-4 text-amber-600" />
|
||||||
|
</div>
|
||||||
|
<span className="font-bold text-sm">快速审计</span>
|
||||||
|
{value === "fast" && (
|
||||||
|
<CheckCircle2 className="w-4 h-4 text-amber-600 ml-auto" />
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<ul className="text-xs text-gray-600 space-y-1 mb-3">
|
||||||
|
<li className="flex items-center gap-1">
|
||||||
|
<Clock className="w-3 h-3" />
|
||||||
|
速度快(分钟级)
|
||||||
|
</li>
|
||||||
|
<li className="flex items-center gap-1">
|
||||||
|
<Code className="w-3 h-3" />
|
||||||
|
逐文件 LLM 分析
|
||||||
|
</li>
|
||||||
|
<li className="flex items-center gap-1 text-gray-400">
|
||||||
|
<Shield className="w-3 h-3" />
|
||||||
|
无漏洞验证
|
||||||
|
</li>
|
||||||
|
</ul>
|
||||||
|
|
||||||
|
<div className="mt-auto pt-2 border-t border-gray-200">
|
||||||
|
<span className="text-[10px] uppercase tracking-wider text-gray-500 font-bold">
|
||||||
|
适合: CI/CD 集成、日常检查
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
</label>
|
||||||
|
|
||||||
|
{/* Agent 审计模式 */}
|
||||||
|
<label
|
||||||
|
className={cn(
|
||||||
|
"relative flex flex-col p-4 border-2 cursor-pointer transition-all rounded-none",
|
||||||
|
value === "agent"
|
||||||
|
? "border-purple-500 bg-purple-50"
|
||||||
|
: "border-gray-300 hover:border-gray-400 bg-white",
|
||||||
|
disabled && "opacity-50 cursor-not-allowed"
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
<input
|
||||||
|
type="radio"
|
||||||
|
name="auditMode"
|
||||||
|
value="agent"
|
||||||
|
checked={value === "agent"}
|
||||||
|
onChange={() => onChange("agent")}
|
||||||
|
disabled={disabled}
|
||||||
|
className="sr-only"
|
||||||
|
/>
|
||||||
|
|
||||||
|
{/* 推荐标签 */}
|
||||||
|
<div className="absolute -top-2 -right-2 px-2 py-0.5 bg-purple-600 text-white text-[10px] font-bold uppercase">
|
||||||
|
推荐
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="flex items-center gap-2 mb-2">
|
||||||
|
<div className="p-1.5 bg-purple-100 border border-purple-300">
|
||||||
|
<Bot className="w-4 h-4 text-purple-600" />
|
||||||
|
</div>
|
||||||
|
<span className="font-bold text-sm">Agent 审计</span>
|
||||||
|
{value === "agent" && (
|
||||||
|
<CheckCircle2 className="w-4 h-4 text-purple-600 ml-auto" />
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<ul className="text-xs text-gray-600 space-y-1 mb-3">
|
||||||
|
<li className="flex items-center gap-1">
|
||||||
|
<Bot className="w-3 h-3" />
|
||||||
|
AI Agent 自主分析
|
||||||
|
</li>
|
||||||
|
<li className="flex items-center gap-1">
|
||||||
|
<Code className="w-3 h-3" />
|
||||||
|
跨文件关联 + RAG
|
||||||
|
</li>
|
||||||
|
<li className="flex items-center gap-1 text-purple-600 font-medium">
|
||||||
|
<Shield className="w-3 h-3" />
|
||||||
|
沙箱漏洞验证
|
||||||
|
</li>
|
||||||
|
</ul>
|
||||||
|
|
||||||
|
<div className="mt-auto pt-2 border-t border-gray-200">
|
||||||
|
<span className="text-[10px] uppercase tracking-wider text-gray-500 font-bold">
|
||||||
|
适合: 发版前审计、深度安全评估
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
</label>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* 模式说明 */}
|
||||||
|
{value === "agent" && (
|
||||||
|
<div className="p-3 bg-purple-50 border border-purple-200 text-xs text-purple-800 rounded-none">
|
||||||
|
<p className="font-bold mb-1">🤖 Agent 审计模式说明:</p>
|
||||||
|
<ul className="list-disc list-inside space-y-0.5 text-purple-700">
|
||||||
|
<li>AI Agent 会自主规划审计策略</li>
|
||||||
|
<li>使用 RAG 技术进行代码语义检索</li>
|
||||||
|
<li>在 Docker 沙箱中验证发现的漏洞</li>
|
||||||
|
<li>可生成可复现的 PoC(概念验证)代码</li>
|
||||||
|
<li>审计时间较长,但结果更全面准确</li>
|
||||||
|
</ul>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
@ -0,0 +1,449 @@
|
||||||
|
/**
|
||||||
|
* 嵌入模型配置组件
|
||||||
|
* 独立于 LLM 配置,专门用于 Agent 审计的 RAG 系统
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { useState, useEffect } from "react";
|
||||||
|
import {
|
||||||
|
Card,
|
||||||
|
CardContent,
|
||||||
|
CardDescription,
|
||||||
|
CardHeader,
|
||||||
|
CardTitle,
|
||||||
|
} from "@/components/ui/card";
|
||||||
|
import { Button } from "@/components/ui/button";
|
||||||
|
import { Input } from "@/components/ui/input";
|
||||||
|
import { Label } from "@/components/ui/label";
|
||||||
|
import {
|
||||||
|
Select,
|
||||||
|
SelectContent,
|
||||||
|
SelectItem,
|
||||||
|
SelectTrigger,
|
||||||
|
SelectValue,
|
||||||
|
} from "@/components/ui/select";
|
||||||
|
import { Badge } from "@/components/ui/badge";
|
||||||
|
import { Separator } from "@/components/ui/separator";
|
||||||
|
import {
|
||||||
|
Brain,
|
||||||
|
Cpu,
|
||||||
|
Check,
|
||||||
|
X,
|
||||||
|
Loader2,
|
||||||
|
RefreshCw,
|
||||||
|
Server,
|
||||||
|
Key,
|
||||||
|
Zap,
|
||||||
|
Info,
|
||||||
|
} from "lucide-react";
|
||||||
|
import { toast } from "sonner";
|
||||||
|
import { apiClient } from "@/shared/api/serverClient";
|
||||||
|
|
||||||
|
interface EmbeddingProvider {
|
||||||
|
id: string;
|
||||||
|
name: string;
|
||||||
|
description: string;
|
||||||
|
models: string[];
|
||||||
|
requires_api_key: boolean;
|
||||||
|
default_model: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
interface EmbeddingConfig {
|
||||||
|
provider: string;
|
||||||
|
model: string;
|
||||||
|
base_url: string | null;
|
||||||
|
dimensions: number;
|
||||||
|
batch_size: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
interface TestResult {
|
||||||
|
success: boolean;
|
||||||
|
message: string;
|
||||||
|
dimensions?: number;
|
||||||
|
sample_embedding?: number[];
|
||||||
|
latency_ms?: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
export default function EmbeddingConfigPanel() {
|
||||||
|
const [providers, setProviders] = useState<EmbeddingProvider[]>([]);
|
||||||
|
const [currentConfig, setCurrentConfig] = useState<EmbeddingConfig | null>(null);
|
||||||
|
const [loading, setLoading] = useState(true);
|
||||||
|
const [saving, setSaving] = useState(false);
|
||||||
|
const [testing, setTesting] = useState(false);
|
||||||
|
const [testResult, setTestResult] = useState<TestResult | null>(null);
|
||||||
|
|
||||||
|
// 表单状态
|
||||||
|
const [selectedProvider, setSelectedProvider] = useState("");
|
||||||
|
const [selectedModel, setSelectedModel] = useState("");
|
||||||
|
const [apiKey, setApiKey] = useState("");
|
||||||
|
const [baseUrl, setBaseUrl] = useState("");
|
||||||
|
const [batchSize, setBatchSize] = useState(100);
|
||||||
|
|
||||||
|
// 加载数据
|
||||||
|
useEffect(() => {
|
||||||
|
loadData();
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
// 当 provider 改变时更新模型
|
||||||
|
useEffect(() => {
|
||||||
|
if (selectedProvider) {
|
||||||
|
const provider = providers.find((p) => p.id === selectedProvider);
|
||||||
|
if (provider) {
|
||||||
|
setSelectedModel(provider.default_model);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}, [selectedProvider, providers]);
|
||||||
|
|
||||||
|
const loadData = async () => {
|
||||||
|
try {
|
||||||
|
setLoading(true);
|
||||||
|
const [providersRes, configRes] = await Promise.all([
|
||||||
|
apiClient.get("/embedding/providers"),
|
||||||
|
apiClient.get("/embedding/config"),
|
||||||
|
]);
|
||||||
|
|
||||||
|
setProviders(providersRes.data);
|
||||||
|
setCurrentConfig(configRes.data);
|
||||||
|
|
||||||
|
// 设置表单默认值
|
||||||
|
if (configRes.data) {
|
||||||
|
setSelectedProvider(configRes.data.provider);
|
||||||
|
setSelectedModel(configRes.data.model);
|
||||||
|
setBaseUrl(configRes.data.base_url || "");
|
||||||
|
setBatchSize(configRes.data.batch_size);
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
toast.error("加载配置失败");
|
||||||
|
} finally {
|
||||||
|
setLoading(false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleSave = async () => {
|
||||||
|
if (!selectedProvider || !selectedModel) {
|
||||||
|
toast.error("请选择提供商和模型");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const provider = providers.find((p) => p.id === selectedProvider);
|
||||||
|
if (provider?.requires_api_key && !apiKey) {
|
||||||
|
toast.error(`${provider.name} 需要 API Key`);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
setSaving(true);
|
||||||
|
await apiClient.put("/embedding/config", {
|
||||||
|
provider: selectedProvider,
|
||||||
|
model: selectedModel,
|
||||||
|
api_key: apiKey || undefined,
|
||||||
|
base_url: baseUrl || undefined,
|
||||||
|
batch_size: batchSize,
|
||||||
|
});
|
||||||
|
|
||||||
|
toast.success("配置已保存");
|
||||||
|
await loadData();
|
||||||
|
} catch (error: any) {
|
||||||
|
toast.error(error.response?.data?.detail || "保存失败");
|
||||||
|
} finally {
|
||||||
|
setSaving(false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleTest = async () => {
|
||||||
|
if (!selectedProvider || !selectedModel) {
|
||||||
|
toast.error("请选择提供商和模型");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
setTesting(true);
|
||||||
|
setTestResult(null);
|
||||||
|
|
||||||
|
const response = await apiClient.post("/embedding/test", {
|
||||||
|
provider: selectedProvider,
|
||||||
|
model: selectedModel,
|
||||||
|
api_key: apiKey || undefined,
|
||||||
|
base_url: baseUrl || undefined,
|
||||||
|
});
|
||||||
|
|
||||||
|
setTestResult(response.data);
|
||||||
|
|
||||||
|
if (response.data.success) {
|
||||||
|
toast.success("测试成功");
|
||||||
|
} else {
|
||||||
|
toast.error("测试失败");
|
||||||
|
}
|
||||||
|
} catch (error: any) {
|
||||||
|
setTestResult({
|
||||||
|
success: false,
|
||||||
|
message: error.response?.data?.detail || "测试失败",
|
||||||
|
});
|
||||||
|
toast.error("测试失败");
|
||||||
|
} finally {
|
||||||
|
setTesting(false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const selectedProviderInfo = providers.find((p) => p.id === selectedProvider);
|
||||||
|
|
||||||
|
if (loading) {
|
||||||
|
return (
|
||||||
|
<div className="flex items-center justify-center p-8">
|
||||||
|
<Loader2 className="w-6 h-6 animate-spin" />
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Card className="border-2 border-black rounded-none shadow-[4px_4px_0px_0px_rgba(0,0,0,1)]">
|
||||||
|
<CardHeader className="border-b-2 border-black bg-purple-50">
|
||||||
|
<div className="flex items-center gap-3">
|
||||||
|
<div className="p-2 bg-purple-100 border-2 border-purple-300">
|
||||||
|
<Brain className="w-5 h-5 text-purple-600" />
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<CardTitle className="font-mono text-lg">嵌入模型配置</CardTitle>
|
||||||
|
<CardDescription>
|
||||||
|
用于 Agent 审计的 RAG 代码检索,独立于分析 LLM
|
||||||
|
</CardDescription>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</CardHeader>
|
||||||
|
|
||||||
|
<CardContent className="p-6 space-y-6">
|
||||||
|
{/* 当前配置状态 */}
|
||||||
|
{currentConfig && (
|
||||||
|
<div className="p-4 bg-gray-50 border-2 border-gray-200 space-y-2">
|
||||||
|
<div className="flex items-center gap-2 text-sm font-mono font-bold">
|
||||||
|
<Server className="w-4 h-4" />
|
||||||
|
当前配置
|
||||||
|
</div>
|
||||||
|
<div className="grid grid-cols-2 gap-4 text-sm">
|
||||||
|
<div>
|
||||||
|
<span className="text-gray-500">提供商:</span>{" "}
|
||||||
|
<Badge variant="outline" className="ml-1">
|
||||||
|
{currentConfig.provider}
|
||||||
|
</Badge>
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<span className="text-gray-500">模型:</span>{" "}
|
||||||
|
<span className="font-mono">{currentConfig.model}</span>
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<span className="text-gray-500">向量维度:</span>{" "}
|
||||||
|
<span className="font-mono">{currentConfig.dimensions}</span>
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<span className="text-gray-500">批处理大小:</span>{" "}
|
||||||
|
<span className="font-mono">{currentConfig.batch_size}</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
<Separator />
|
||||||
|
|
||||||
|
{/* 提供商选择 */}
|
||||||
|
<div className="space-y-2">
|
||||||
|
<Label className="font-mono font-bold">嵌入模型提供商</Label>
|
||||||
|
<Select value={selectedProvider} onValueChange={setSelectedProvider}>
|
||||||
|
<SelectTrigger className="border-2 border-black rounded-none">
|
||||||
|
<SelectValue placeholder="选择提供商" />
|
||||||
|
</SelectTrigger>
|
||||||
|
<SelectContent className="border-2 border-black rounded-none">
|
||||||
|
{providers.map((provider) => (
|
||||||
|
<SelectItem key={provider.id} value={provider.id}>
|
||||||
|
<div className="flex items-center gap-2">
|
||||||
|
<span>{provider.name}</span>
|
||||||
|
{provider.requires_api_key ? (
|
||||||
|
<Key className="w-3 h-3 text-amber-500" />
|
||||||
|
) : (
|
||||||
|
<Cpu className="w-3 h-3 text-green-500" />
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</SelectItem>
|
||||||
|
))}
|
||||||
|
</SelectContent>
|
||||||
|
</Select>
|
||||||
|
|
||||||
|
{selectedProviderInfo && (
|
||||||
|
<p className="text-xs text-gray-500 flex items-center gap-1">
|
||||||
|
<Info className="w-3 h-3" />
|
||||||
|
{selectedProviderInfo.description}
|
||||||
|
</p>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* 模型选择 */}
|
||||||
|
{selectedProviderInfo && (
|
||||||
|
<div className="space-y-2">
|
||||||
|
<Label className="font-mono font-bold">模型</Label>
|
||||||
|
<Select value={selectedModel} onValueChange={setSelectedModel}>
|
||||||
|
<SelectTrigger className="border-2 border-black rounded-none">
|
||||||
|
<SelectValue placeholder="选择模型" />
|
||||||
|
</SelectTrigger>
|
||||||
|
<SelectContent className="border-2 border-black rounded-none">
|
||||||
|
{selectedProviderInfo.models.map((model) => (
|
||||||
|
<SelectItem key={model} value={model}>
|
||||||
|
<span className="font-mono text-sm">{model}</span>
|
||||||
|
</SelectItem>
|
||||||
|
))}
|
||||||
|
</SelectContent>
|
||||||
|
</Select>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* API Key */}
|
||||||
|
{selectedProviderInfo?.requires_api_key && (
|
||||||
|
<div className="space-y-2">
|
||||||
|
<Label className="font-mono font-bold">
|
||||||
|
API Key
|
||||||
|
<span className="text-red-500 ml-1">*</span>
|
||||||
|
</Label>
|
||||||
|
<Input
|
||||||
|
type="password"
|
||||||
|
value={apiKey}
|
||||||
|
onChange={(e) => setApiKey(e.target.value)}
|
||||||
|
placeholder="输入 API Key"
|
||||||
|
className="border-2 border-black rounded-none font-mono"
|
||||||
|
/>
|
||||||
|
<p className="text-xs text-gray-500">
|
||||||
|
API Key 将安全存储,不会显示在页面上
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* 自定义端点 */}
|
||||||
|
<div className="space-y-2">
|
||||||
|
<Label className="font-mono font-bold">
|
||||||
|
自定义 API 端点 <span className="text-gray-400">(可选)</span>
|
||||||
|
</Label>
|
||||||
|
<Input
|
||||||
|
type="url"
|
||||||
|
value={baseUrl}
|
||||||
|
onChange={(e) => setBaseUrl(e.target.value)}
|
||||||
|
placeholder={
|
||||||
|
selectedProvider === "ollama"
|
||||||
|
? "http://localhost:11434"
|
||||||
|
: selectedProvider === "huggingface"
|
||||||
|
? "https://router.huggingface.co"
|
||||||
|
: selectedProvider === "cohere"
|
||||||
|
? "https://api.cohere.com/v2"
|
||||||
|
: selectedProvider === "jina"
|
||||||
|
? "https://api.jina.ai/v1"
|
||||||
|
: "https://api.openai.com/v1"
|
||||||
|
}
|
||||||
|
className="border-2 border-black rounded-none font-mono"
|
||||||
|
/>
|
||||||
|
<p className="text-xs text-gray-500">
|
||||||
|
用于 API 代理或自托管服务
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* 批处理大小 */}
|
||||||
|
<div className="space-y-2">
|
||||||
|
<Label className="font-mono font-bold">批处理大小</Label>
|
||||||
|
<Input
|
||||||
|
type="number"
|
||||||
|
value={batchSize}
|
||||||
|
onChange={(e) => setBatchSize(parseInt(e.target.value) || 100)}
|
||||||
|
min={1}
|
||||||
|
max={500}
|
||||||
|
className="border-2 border-black rounded-none font-mono w-32"
|
||||||
|
/>
|
||||||
|
<p className="text-xs text-gray-500">
|
||||||
|
每批嵌入的文本数量,建议 50-100
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* 测试结果 */}
|
||||||
|
{testResult && (
|
||||||
|
<div
|
||||||
|
className={`p-4 border-2 ${
|
||||||
|
testResult.success
|
||||||
|
? "border-green-500 bg-green-50"
|
||||||
|
: "border-red-500 bg-red-50"
|
||||||
|
}`}
|
||||||
|
>
|
||||||
|
<div className="flex items-center gap-2 mb-2">
|
||||||
|
{testResult.success ? (
|
||||||
|
<Check className="w-5 h-5 text-green-600" />
|
||||||
|
) : (
|
||||||
|
<X className="w-5 h-5 text-red-600" />
|
||||||
|
)}
|
||||||
|
<span
|
||||||
|
className={`font-bold ${
|
||||||
|
testResult.success ? "text-green-700" : "text-red-700"
|
||||||
|
}`}
|
||||||
|
>
|
||||||
|
{testResult.success ? "测试成功" : "测试失败"}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
<p className="text-sm">{testResult.message}</p>
|
||||||
|
{testResult.success && (
|
||||||
|
<div className="mt-2 text-xs text-gray-600 space-y-1">
|
||||||
|
<div>向量维度: {testResult.dimensions}</div>
|
||||||
|
<div>延迟: {testResult.latency_ms}ms</div>
|
||||||
|
{testResult.sample_embedding && (
|
||||||
|
<div>
|
||||||
|
示例向量: [{testResult.sample_embedding.map((v) => v.toFixed(4)).join(", ")}...]
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* 操作按钮 */}
|
||||||
|
<div className="flex items-center gap-3 pt-4">
|
||||||
|
<Button
|
||||||
|
onClick={handleTest}
|
||||||
|
disabled={testing || !selectedProvider || !selectedModel}
|
||||||
|
variant="outline"
|
||||||
|
className="border-2 border-black rounded-none hover:bg-gray-100"
|
||||||
|
>
|
||||||
|
{testing ? (
|
||||||
|
<Loader2 className="w-4 h-4 mr-2 animate-spin" />
|
||||||
|
) : (
|
||||||
|
<Zap className="w-4 h-4 mr-2" />
|
||||||
|
)}
|
||||||
|
测试连接
|
||||||
|
</Button>
|
||||||
|
|
||||||
|
<Button
|
||||||
|
onClick={handleSave}
|
||||||
|
disabled={saving || !selectedProvider || !selectedModel}
|
||||||
|
className="bg-purple-600 hover:bg-purple-700 border-2 border-black rounded-none"
|
||||||
|
>
|
||||||
|
{saving ? (
|
||||||
|
<Loader2 className="w-4 h-4 mr-2 animate-spin" />
|
||||||
|
) : (
|
||||||
|
<Check className="w-4 h-4 mr-2" />
|
||||||
|
)}
|
||||||
|
保存配置
|
||||||
|
</Button>
|
||||||
|
|
||||||
|
<Button
|
||||||
|
onClick={loadData}
|
||||||
|
variant="ghost"
|
||||||
|
className="ml-auto"
|
||||||
|
>
|
||||||
|
<RefreshCw className="w-4 h-4" />
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* 说明 */}
|
||||||
|
<div className="p-4 bg-blue-50 border-l-4 border-blue-500 text-sm">
|
||||||
|
<p className="font-bold mb-1">💡 关于嵌入模型</p>
|
||||||
|
<ul className="list-disc list-inside text-gray-600 space-y-1">
|
||||||
|
<li>嵌入模型用于 Agent 审计的代码语义搜索 (RAG)</li>
|
||||||
|
<li>与分析使用的 LLM 独立配置,互不影响</li>
|
||||||
|
<li>推荐使用 OpenAI text-embedding-3-small 或本地 Ollama</li>
|
||||||
|
<li>向量维度影响存储空间和检索精度</li>
|
||||||
|
</ul>
|
||||||
|
</div>
|
||||||
|
</CardContent>
|
||||||
|
</Card>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
import { useState, useEffect, useMemo, useRef } from "react";
|
import { useState, useEffect, useMemo, useRef } from "react";
|
||||||
|
import { useNavigate } from "react-router-dom";
|
||||||
import {
|
import {
|
||||||
Dialog,
|
Dialog,
|
||||||
DialogContent,
|
DialogContent,
|
||||||
|
|
@ -35,16 +36,19 @@ import {
|
||||||
Shield,
|
Shield,
|
||||||
Loader2,
|
Loader2,
|
||||||
Zap,
|
Zap,
|
||||||
|
Bot,
|
||||||
} from "lucide-react";
|
} from "lucide-react";
|
||||||
import { toast } from "sonner";
|
import { toast } from "sonner";
|
||||||
import { api } from "@/shared/config/database";
|
import { api } from "@/shared/config/database";
|
||||||
import { getRuleSets, type AuditRuleSet } from "@/shared/api/rules";
|
import { getRuleSets, type AuditRuleSet } from "@/shared/api/rules";
|
||||||
import { getPromptTemplates, type PromptTemplate } from "@/shared/api/prompts";
|
import { getPromptTemplates, type PromptTemplate } from "@/shared/api/prompts";
|
||||||
|
import { createAgentTask } from "@/shared/api/agentTasks";
|
||||||
|
|
||||||
import { useProjects } from "./hooks/useTaskForm";
|
import { useProjects } from "./hooks/useTaskForm";
|
||||||
import { useZipFile, formatFileSize } from "./hooks/useZipFile";
|
import { useZipFile, formatFileSize } from "./hooks/useZipFile";
|
||||||
import TerminalProgressDialog from "./TerminalProgressDialog";
|
import TerminalProgressDialog from "./TerminalProgressDialog";
|
||||||
import FileSelectionDialog from "./FileSelectionDialog";
|
import FileSelectionDialog from "./FileSelectionDialog";
|
||||||
|
import AgentModeSelector, { type AuditMode } from "@/components/agent/AgentModeSelector";
|
||||||
|
|
||||||
import { runRepositoryAudit } from "@/features/projects/services/repoScan";
|
import { runRepositoryAudit } from "@/features/projects/services/repoScan";
|
||||||
import {
|
import {
|
||||||
|
|
@ -76,6 +80,7 @@ export default function CreateTaskDialog({
|
||||||
onTaskCreated,
|
onTaskCreated,
|
||||||
preselectedProjectId,
|
preselectedProjectId,
|
||||||
}: CreateTaskDialogProps) {
|
}: CreateTaskDialogProps) {
|
||||||
|
const navigate = useNavigate();
|
||||||
const [selectedProjectId, setSelectedProjectId] = useState<string>("");
|
const [selectedProjectId, setSelectedProjectId] = useState<string>("");
|
||||||
const [searchTerm, setSearchTerm] = useState("");
|
const [searchTerm, setSearchTerm] = useState("");
|
||||||
const [branch, setBranch] = useState("main");
|
const [branch, setBranch] = useState("main");
|
||||||
|
|
@ -90,6 +95,9 @@ export default function CreateTaskDialog({
|
||||||
const [showTerminal, setShowTerminal] = useState(false);
|
const [showTerminal, setShowTerminal] = useState(false);
|
||||||
const [currentTaskId, setCurrentTaskId] = useState<string | null>(null);
|
const [currentTaskId, setCurrentTaskId] = useState<string | null>(null);
|
||||||
|
|
||||||
|
// 审计模式
|
||||||
|
const [auditMode, setAuditMode] = useState<AuditMode>("agent");
|
||||||
|
|
||||||
// 规则集和提示词模板
|
// 规则集和提示词模板
|
||||||
const [ruleSets, setRuleSets] = useState<AuditRuleSet[]>([]);
|
const [ruleSets, setRuleSets] = useState<AuditRuleSet[]>([]);
|
||||||
const [promptTemplates, setPromptTemplates] = useState<PromptTemplate[]>([]);
|
const [promptTemplates, setPromptTemplates] = useState<PromptTemplate[]>([]);
|
||||||
|
|
@ -205,6 +213,31 @@ export default function CreateTaskDialog({
|
||||||
setCreating(true);
|
setCreating(true);
|
||||||
let taskId: string;
|
let taskId: string;
|
||||||
|
|
||||||
|
// Agent 审计模式
|
||||||
|
if (auditMode === "agent") {
|
||||||
|
const agentTask = await createAgentTask({
|
||||||
|
project_id: selectedProject.id,
|
||||||
|
name: `Agent审计-${selectedProject.name}`,
|
||||||
|
branch_name: isRepositoryProject(selectedProject) ? branch : undefined,
|
||||||
|
exclude_patterns: excludePatterns,
|
||||||
|
target_files: selectedFiles,
|
||||||
|
verification_level: "sandbox",
|
||||||
|
});
|
||||||
|
|
||||||
|
onOpenChange(false);
|
||||||
|
onTaskCreated();
|
||||||
|
toast.success("Agent 审计任务已创建");
|
||||||
|
|
||||||
|
// 导航到 Agent 审计页面
|
||||||
|
navigate(`/agent-audit/${agentTask.id}`);
|
||||||
|
|
||||||
|
setSelectedProjectId("");
|
||||||
|
setSelectedFiles(undefined);
|
||||||
|
setExcludePatterns(DEFAULT_EXCLUDES);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 快速审计模式(原有逻辑)
|
||||||
if (isZipProject(selectedProject)) {
|
if (isZipProject(selectedProject)) {
|
||||||
if (zipState.useStoredZip && zipState.storedZipInfo?.has_file) {
|
if (zipState.useStoredZip && zipState.storedZipInfo?.has_file) {
|
||||||
taskId = await scanStoredZipFile({
|
taskId = await scanStoredZipFile({
|
||||||
|
|
@ -339,6 +372,15 @@ export default function CreateTaskDialog({
|
||||||
</ScrollArea>
|
</ScrollArea>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
{/* 审计模式选择 */}
|
||||||
|
{selectedProject && (
|
||||||
|
<AgentModeSelector
|
||||||
|
value={auditMode}
|
||||||
|
onChange={setAuditMode}
|
||||||
|
disabled={creating}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
|
||||||
{/* 配置区域 */}
|
{/* 配置区域 */}
|
||||||
{selectedProject && (
|
{selectedProject && (
|
||||||
<div className="space-y-4">
|
<div className="space-y-4">
|
||||||
|
|
@ -589,10 +631,15 @@ export default function CreateTaskDialog({
|
||||||
<div className="animate-spin h-4 w-4 border-2 border-white border-t-transparent mr-2" />
|
<div className="animate-spin h-4 w-4 border-2 border-white border-t-transparent mr-2" />
|
||||||
启动中...
|
启动中...
|
||||||
</>
|
</>
|
||||||
|
) : auditMode === "agent" ? (
|
||||||
|
<>
|
||||||
|
<Bot className="w-4 h-4 mr-2" />
|
||||||
|
启动 Agent 审计
|
||||||
|
</>
|
||||||
) : (
|
) : (
|
||||||
<>
|
<>
|
||||||
<Play className="w-4 h-4 mr-2" />
|
<Zap className="w-4 h-4 mr-2" />
|
||||||
开始扫描
|
开始快速扫描
|
||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
</Button>
|
</Button>
|
||||||
|
|
|
||||||
|
|
@ -6,10 +6,11 @@ import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@
import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
import {
  Settings, Save, RotateCcw, Eye, EyeOff, CheckCircle2, AlertCircle,
-  Info, Zap, Globe, PlayCircle, Loader2
+  Info, Zap, Globe, PlayCircle, Loader2, Brain
} from "lucide-react";
import { toast } from "sonner";
import { api } from "@/shared/api/database";
import EmbeddingConfig from "@/components/agent/EmbeddingConfig";

// LLM provider configuration - updated for 2025
const LLM_PROVIDERS = [

@ -246,10 +247,13 @@ export function SystemConfig() {
      </div>

      <Tabs defaultValue="llm" className="w-full">
-        <TabsList className="grid w-full grid-cols-3 bg-transparent border-2 border-black p-0 h-auto gap-0 mb-6">
+        <TabsList className="grid w-full grid-cols-4 bg-transparent border-2 border-black p-0 h-auto gap-0 mb-6">
          <TabsTrigger value="llm" className="rounded-none border-r-2 border-black data-[state=active]:bg-black data-[state=active]:text-white font-mono font-bold uppercase h-10 text-xs">
            <Zap className="w-3 h-3 mr-2" /> LLM 配置
          </TabsTrigger>
          <TabsTrigger value="embedding" className="rounded-none border-r-2 border-black data-[state=active]:bg-black data-[state=active]:text-white font-mono font-bold uppercase h-10 text-xs">
            <Brain className="w-3 h-3 mr-2" /> 嵌入模型
          </TabsTrigger>
          <TabsTrigger value="analysis" className="rounded-none border-r-2 border-black data-[state=active]:bg-black data-[state=active]:text-white font-mono font-bold uppercase h-10 text-xs">
            <Settings className="w-3 h-3 mr-2" /> 分析参数
          </TabsTrigger>

@ -388,6 +392,11 @@ export function SystemConfig() {
        </div>
      </TabsContent>

      {/* Embedding model configuration */}
      <TabsContent value="embedding" className="space-y-6">
        <EmbeddingConfig />
      </TabsContent>

      {/* Analysis parameters */}
      <TabsContent value="analysis" className="space-y-6">
        <div className="retro-card bg-white border-2 border-black shadow-[4px_4px_0px_0px_rgba(0,0,0,1)] p-6 space-y-6">

@ -0,0 +1,579 @@
|
||||||
|
/**
|
||||||
|
* Agent 审计页面
|
||||||
|
* 机械终端风格的 AI Agent 审计界面
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { useState, useEffect, useRef, useCallback } from "react";
|
||||||
|
import { useParams, useNavigate } from "react-router-dom";
|
||||||
|
import {
|
||||||
|
Terminal, Bot, Cpu, Shield, AlertTriangle, CheckCircle2,
|
||||||
|
Loader2, Code, Zap, Activity, ChevronRight, XCircle,
|
||||||
|
FileCode, Search, Bug, Lock, Play, Square, RefreshCw,
|
||||||
|
ArrowLeft, Download, ExternalLink
|
||||||
|
} from "lucide-react";
|
||||||
|
import { Button } from "@/components/ui/button";
|
||||||
|
import { Badge } from "@/components/ui/badge";
|
||||||
|
import { ScrollArea } from "@/components/ui/scroll-area";
|
||||||
|
import { toast } from "sonner";
|
||||||
|
import {
|
||||||
|
type AgentTask,
|
||||||
|
type AgentEvent,
|
||||||
|
type AgentFinding,
|
||||||
|
getAgentTask,
|
||||||
|
getAgentEvents,
|
||||||
|
getAgentFindings,
|
||||||
|
cancelAgentTask,
|
||||||
|
streamAgentEvents,
|
||||||
|
} from "@/shared/api/agentTasks";
|
||||||
|
|
||||||
|
// Event type icon mapping
const eventTypeIcons: Record<string, ReactNode> = {
  phase_start: <Zap className="w-3 h-3 text-cyan-400" />,
  phase_complete: <CheckCircle2 className="w-3 h-3 text-green-400" />,
  thinking: <Cpu className="w-3 h-3 text-purple-400" />,
  tool_call: <Code className="w-3 h-3 text-yellow-400" />,
  tool_result: <CheckCircle2 className="w-3 h-3 text-green-400" />,
  tool_error: <XCircle className="w-3 h-3 text-red-400" />,
  finding_new: <Bug className="w-3 h-3 text-orange-400" />,
  finding_verified: <Shield className="w-3 h-3 text-red-400" />,
  info: <Activity className="w-3 h-3 text-blue-400" />,
  warning: <AlertTriangle className="w-3 h-3 text-yellow-400" />,
  error: <XCircle className="w-3 h-3 text-red-500" />,
  progress: <RefreshCw className="w-3 h-3 text-cyan-400 animate-spin" />,
  task_complete: <CheckCircle2 className="w-3 h-3 text-green-500" />,
  task_error: <XCircle className="w-3 h-3 text-red-500" />,
  task_cancel: <Square className="w-3 h-3 text-yellow-500" />,
};

// Event type color mapping
const eventTypeColors: Record<string, string> = {
  phase_start: "text-cyan-400 font-bold",
  phase_complete: "text-green-400",
  thinking: "text-purple-300",
  tool_call: "text-yellow-300",
  tool_result: "text-green-300",
  tool_error: "text-red-400",
  finding_new: "text-orange-300",
  finding_verified: "text-red-300",
  info: "text-gray-300",
  warning: "text-yellow-300",
  error: "text-red-400",
  progress: "text-cyan-300",
  task_complete: "text-green-400 font-bold",
  task_error: "text-red-400 font-bold",
  task_cancel: "text-yellow-400",
};

// Severity colors
const severityColors: Record<string, string> = {
  critical: "bg-red-900/50 border-red-500 text-red-300",
  high: "bg-orange-900/50 border-orange-500 text-orange-300",
  medium: "bg-yellow-900/50 border-yellow-500 text-yellow-300",
  low: "bg-blue-900/50 border-blue-500 text-blue-300",
  info: "bg-gray-900/50 border-gray-500 text-gray-300",
};

const severityIcons: Record<string, string> = {
  critical: "🔴",
  high: "🟠",
  medium: "🟡",
  low: "🟢",
  info: "⚪",
};
export default function AgentAuditPage() {
  const { taskId } = useParams<{ taskId: string }>();
  const navigate = useNavigate();

  const [task, setTask] = useState<AgentTask | null>(null);
  const [events, setEvents] = useState<AgentEvent[]>([]);
  const [findings, setFindings] = useState<AgentFinding[]>([]);
  const [isLoading, setIsLoading] = useState(true);
  const [isStreaming, setIsStreaming] = useState(false);
  const [currentTime, setCurrentTime] = useState(new Date().toLocaleTimeString("zh-CN", { hour12: false }));

  const eventsEndRef = useRef<HTMLDivElement>(null);
  const abortControllerRef = useRef<AbortController | null>(null);

  // Whether the task has reached a terminal state
  const isComplete = task?.status === "completed" || task?.status === "failed" || task?.status === "cancelled";

  // Load task info
  const loadTask = useCallback(async () => {
    if (!taskId) return;

    try {
      const taskData = await getAgentTask(taskId);
      setTask(taskData);
    } catch (error) {
      console.error("Failed to load task:", error);
      toast.error("Failed to load task");
    }
  }, [taskId]);

  // Load events
  const loadEvents = useCallback(async () => {
    if (!taskId) return;

    try {
      const eventsData = await getAgentEvents(taskId, { limit: 500 });
      setEvents(eventsData);
    } catch (error) {
      console.error("Failed to load events:", error);
    }
  }, [taskId]);

  // Load findings
  const loadFindings = useCallback(async () => {
    if (!taskId) return;

    try {
      const findingsData = await getAgentFindings(taskId);
      setFindings(findingsData);
    } catch (error) {
      console.error("Failed to load findings:", error);
    }
  }, [taskId]);

  // Initial load
  useEffect(() => {
    const init = async () => {
      setIsLoading(true);
      await Promise.all([loadTask(), loadEvents(), loadFindings()]);
      setIsLoading(false);
    };

    init();
  }, [loadTask, loadEvents, loadFindings]);

  // Event stream
  useEffect(() => {
    if (!taskId || isComplete || isLoading) return;

    const startStreaming = async () => {
      setIsStreaming(true);
      abortControllerRef.current = new AbortController();

      try {
        const lastSequence = events.length > 0 ? Math.max(...events.map(e => e.sequence)) : 0;

        for await (const event of streamAgentEvents(taskId, lastSequence, abortControllerRef.current.signal)) {
          setEvents(prev => {
            // Skip duplicates
            if (prev.some(e => e.id === event.id)) return prev;
            return [...prev, event];
          });

          // On finding events, refresh the findings list
          if (event.event_type.startsWith("finding_")) {
            loadFindings();
          }

          // On terminal events, refresh task status and findings
          if (["task_complete", "task_error", "task_cancel"].includes(event.event_type)) {
            loadTask();
            loadFindings();
          }
        }
      } catch (error) {
        if ((error as Error).name !== "AbortError") {
          console.error("Event stream error:", error);
        }
      } finally {
        setIsStreaming(false);
      }
    };

    startStreaming();

    return () => {
      abortControllerRef.current?.abort();
    };
  }, [taskId, isComplete, isLoading, loadTask, loadFindings]);

  // Auto-scroll to the latest event
  useEffect(() => {
    eventsEndRef.current?.scrollIntoView({ behavior: "smooth" });
  }, [events]);

  // Update the clock
  useEffect(() => {
    const interval = setInterval(() => {
      setCurrentTime(new Date().toLocaleTimeString("zh-CN", { hour12: false }));
    }, 1000);

    return () => clearInterval(interval);
  }, []);

  // Cancel the task
  const handleCancel = async () => {
    if (!taskId) return;

    if (!confirm("Cancel this task? Results analyzed so far will be kept.")) {
      return;
    }

    try {
      await cancelAgentTask(taskId);
      toast.success("Task cancelled");
      loadTask();
    } catch (error) {
      toast.error("Failed to cancel task");
    }
  };
  if (isLoading) {
    return (
      <div className="min-h-screen bg-[#0a0a0f] flex items-center justify-center">
        <div className="text-center">
          <Loader2 className="w-12 h-12 text-cyan-400 animate-spin mx-auto mb-4" />
          <p className="text-gray-400 font-mono">Loading...</p>
        </div>
      </div>
    );
  }

  if (!task) {
    return (
      <div className="min-h-screen bg-[#0a0a0f] flex items-center justify-center">
        <div className="text-center">
          <XCircle className="w-12 h-12 text-red-400 mx-auto mb-4" />
          <p className="text-gray-400 font-mono">Task not found</p>
          <Button
            variant="outline"
            className="mt-4"
            onClick={() => navigate("/tasks")}
          >
            Back to task list
          </Button>
        </div>
      </div>
    );
  }
  return (
    <div className="min-h-screen bg-[#0a0a0f] text-white font-mono">
      {/* Top status bar */}
      <div className="h-14 bg-[#12121a] border-b-2 border-cyan-900/50 flex items-center justify-between px-6">
        <div className="flex items-center gap-4">
          <Button
            variant="ghost"
            size="sm"
            onClick={() => navigate(-1)}
            className="text-gray-400 hover:text-white"
          >
            <ArrowLeft className="w-4 h-4 mr-1" />
            Back
          </Button>

          <div className="flex items-center gap-3">
            <div className="p-1.5 bg-cyan-900/30 rounded border border-cyan-700/50">
              <Bot className={`w-5 h-5 text-cyan-400 ${!isComplete && "animate-pulse"}`} />
            </div>
            <div>
              <span className="text-xs text-gray-500 block">AGENT_AUDIT</span>
              <span className="text-sm font-bold tracking-wider text-cyan-400">
                {task.name || `Task ${task.id.slice(0, 8)}`}
              </span>
            </div>
          </div>
        </div>

        <div className="flex items-center gap-6">
          {/* Phase indicator */}
          <PhaseIndicator phase={task.current_phase} status={task.status} />

          {/* Status badge */}
          <StatusBadge status={task.status} />

          {/* Clock */}
          <span className="text-gray-500 text-sm">{currentTime}</span>
        </div>
      </div>

      <div className="flex h-[calc(100vh-56px)]">
        {/* Left: execution log */}
        <div className="flex-1 p-4 flex flex-col min-w-0">
          <div className="flex items-center justify-between mb-3">
            <div className="flex items-center gap-2 text-xs text-cyan-400">
              <Terminal className="w-4 h-4" />
              <span className="uppercase tracking-wider">Execution Log</span>
              {isStreaming && (
                <span className="flex items-center gap-1 text-green-400">
                  <span className="w-2 h-2 bg-green-400 rounded-full animate-pulse" />
                  LIVE
                </span>
              )}
            </div>
            <span className="text-xs text-gray-500">{events.length} events</span>
          </div>

          {/* Terminal window */}
          <div className="flex-1 bg-[#0d0d12] rounded-lg border border-gray-800 overflow-hidden relative">
            {/* CRT effect */}
            <div className="absolute inset-0 pointer-events-none z-10 opacity-[0.03]"
              style={{
                backgroundImage: "repeating-linear-gradient(0deg, transparent, transparent 1px, rgba(0, 255, 255, 0.03) 1px, rgba(0, 255, 255, 0.03) 2px)",
              }}
            />

            <ScrollArea className="h-full">
              <div className="p-4 space-y-1">
                {events.map((event) => (
                  <EventLine key={event.id} event={event} />
                ))}

                {/* Cursor */}
                {!isComplete && (
                  <div className="flex items-center gap-2 mt-2">
                    <span className="text-gray-600 text-xs">{currentTime}</span>
                    <span className="text-cyan-400 animate-pulse">▌</span>
                  </div>
                )}

                <div ref={eventsEndRef} />
              </div>
            </ScrollArea>
          </div>

          {/* Bottom control bar */}
          <div className="mt-3 flex items-center justify-between">
            <div className="flex items-center gap-4">
              {/* Progress */}
              <div className="flex items-center gap-2">
                <span className="text-xs text-gray-500">Progress</span>
                <div className="w-32 h-2 bg-gray-800 rounded-full overflow-hidden">
                  <div
                    className="h-full bg-gradient-to-r from-cyan-600 to-cyan-400 transition-all duration-300"
                    style={{ width: `${task.progress_percentage}%` }}
                  />
                </div>
                <span className="text-xs text-cyan-400">{task.progress_percentage.toFixed(0)}%</span>
              </div>

              {/* Token usage */}
              {task.total_chunks > 0 && (
                <div className="text-xs text-gray-500">
                  Chunks: <span className="text-gray-300">{task.total_chunks}</span>
                </div>
              )}
            </div>

            <div className="flex items-center gap-2">
              {!isComplete && (
                <Button
                  size="sm"
                  variant="outline"
                  onClick={handleCancel}
                  className="h-8 bg-transparent border-red-800 text-red-400 hover:bg-red-900/30 font-mono text-xs"
                >
                  <Square className="w-3 h-3 mr-1" />
                  Cancel Task
                </Button>
              )}

              {isComplete && (
                <Button
                  size="sm"
                  variant="outline"
                  onClick={() => navigate(`/tasks/${taskId}`)}
                  className="h-8 bg-transparent border-cyan-800 text-cyan-400 hover:bg-cyan-900/30 font-mono text-xs"
                >
                  <ExternalLink className="w-3 h-3 mr-1" />
                  View Report
                </Button>
              )}
            </div>
          </div>
        </div>

        {/* Right: findings panel */}
        <div className="w-80 bg-[#12121a] border-l border-gray-800 flex flex-col">
          <div className="p-4 border-b border-gray-800">
            <div className="flex items-center justify-between">
              <div className="flex items-center gap-2 text-xs text-red-400">
                <Shield className="w-4 h-4" />
                <span className="uppercase tracking-wider">Findings</span>
              </div>
              <Badge variant="outline" className="bg-red-900/30 border-red-700 text-red-400">
                {findings.length}
              </Badge>
            </div>

            {/* Severity statistics */}
            <div className="flex items-center gap-3 mt-3 text-xs">
              {task.critical_count > 0 && (
                <span className="text-red-400">🔴 {task.critical_count}</span>
              )}
              {task.high_count > 0 && (
                <span className="text-orange-400">🟠 {task.high_count}</span>
              )}
              {task.medium_count > 0 && (
                <span className="text-yellow-400">🟡 {task.medium_count}</span>
              )}
              {task.low_count > 0 && (
                <span className="text-blue-400">🟢 {task.low_count}</span>
              )}
            </div>
          </div>

          <ScrollArea className="flex-1">
            <div className="p-3 space-y-2">
              {findings.length === 0 ? (
                <div className="text-center text-gray-500 py-8">
                  <Search className="w-8 h-8 mx-auto mb-2 opacity-50" />
                  <p className="text-sm">No findings yet</p>
                </div>
              ) : (
                findings.map((finding) => (
                  <FindingCard key={finding.id} finding={finding} />
                ))
              )}
            </div>
          </ScrollArea>

          {/* Scores */}
          {isComplete && (
            <div className="p-4 border-t border-gray-800 space-y-2">
              <div className="flex items-center justify-between text-xs">
                <span className="text-gray-500">Security Score</span>
                <span className={`font-bold ${
                  task.security_score >= 80 ? "text-green-400" :
                  task.security_score >= 60 ? "text-yellow-400" :
                  "text-red-400"
                }`}>
                  {task.security_score.toFixed(0)}/100
                </span>
              </div>
              <div className="flex items-center justify-between text-xs">
                <span className="text-gray-500">Verified</span>
                <span className="text-cyan-400">{task.verified_count}/{task.findings_count}</span>
              </div>
            </div>
          )}
        </div>
      </div>
    </div>
  );
}
// Phase indicator component
function PhaseIndicator({ phase, status }: { phase: string | null; status: string }) {
  const phases = ["planning", "indexing", "analysis", "verification", "reporting"];
  const currentIndex = phase ? phases.indexOf(phase) : -1;
  const isComplete = status === "completed";
  const isFailed = status === "failed";

  return (
    <div className="flex items-center gap-1">
      {phases.map((p, idx) => {
        const isActive = p === phase;
        const isPast = isComplete || (currentIndex >= 0 && idx < currentIndex);

        return (
          <div
            key={p}
            className={`w-2 h-2 rounded-full transition-all ${
              isActive ? "bg-cyan-400 shadow-[0_0_8px_rgba(34,211,238,0.6)] animate-pulse" :
              isPast ? "bg-cyan-600" :
              isFailed ? "bg-red-900" :
              "bg-gray-700"
            }`}
            title={p}
          />
        );
      })}
      {phase && (
        <span className="ml-2 text-xs text-gray-400 uppercase">{phase}</span>
      )}
    </div>
  );
}

// Status badge component
function StatusBadge({ status }: { status: string }) {
  const statusConfig: Record<string, { text: string; className: string }> = {
    pending: { text: "PENDING", className: "bg-gray-800 text-gray-400 border-gray-600" },
    initializing: { text: "INIT", className: "bg-blue-900/50 text-blue-400 border-blue-600 animate-pulse" },
    planning: { text: "PLANNING", className: "bg-purple-900/50 text-purple-400 border-purple-600 animate-pulse" },
    indexing: { text: "INDEXING", className: "bg-cyan-900/50 text-cyan-400 border-cyan-600 animate-pulse" },
    analyzing: { text: "ANALYZING", className: "bg-yellow-900/50 text-yellow-400 border-yellow-600 animate-pulse" },
    verifying: { text: "VERIFYING", className: "bg-orange-900/50 text-orange-400 border-orange-600 animate-pulse" },
    completed: { text: "COMPLETED", className: "bg-green-900/50 text-green-400 border-green-600" },
    failed: { text: "FAILED", className: "bg-red-900/50 text-red-400 border-red-600" },
    cancelled: { text: "CANCELLED", className: "bg-yellow-900/50 text-yellow-400 border-yellow-600" },
  };

  const config = statusConfig[status] || statusConfig.pending;

  return (
    <Badge variant="outline" className={`${config.className} font-mono text-xs px-2`}>
      {config.text}
    </Badge>
  );
}

// Event line component
function EventLine({ event }: { event: AgentEvent }) {
  const icon = eventTypeIcons[event.event_type] || <ChevronRight className="w-3 h-3 text-gray-500" />;
  const colorClass = eventTypeColors[event.event_type] || "text-gray-400";

  const timestamp = event.timestamp
    ? new Date(event.timestamp).toLocaleTimeString("zh-CN", { hour12: false })
    : "";

  return (
    <div className={`flex items-start gap-2 py-0.5 group hover:bg-white/5 px-1 rounded ${colorClass}`}>
      <span className="text-gray-600 text-xs w-20 flex-shrink-0 group-hover:text-gray-500">
        {timestamp}
      </span>
      <span className="flex-shrink-0 mt-0.5">{icon}</span>
      <span className="flex-1 text-sm break-all">
        {event.message}
        {event.tool_duration_ms && (
          <span className="text-gray-600 ml-2">({event.tool_duration_ms}ms)</span>
        )}
      </span>
    </div>
  );
}

// Finding card component
function FindingCard({ finding }: { finding: AgentFinding }) {
  const colorClass = severityColors[finding.severity] || severityColors.info;
  const icon = severityIcons[finding.severity] || "⚪";

  return (
    <div className={`p-3 rounded border-l-4 ${colorClass} transition-all hover:brightness-110`}>
      <div className="flex items-start gap-2">
        <span>{icon}</span>
        <div className="flex-1 min-w-0">
          <p className="text-sm font-medium truncate">{finding.title}</p>
          <p className="text-xs text-gray-400 mt-0.5">{finding.vulnerability_type}</p>
          {finding.file_path && (
            <p className="text-xs text-gray-500 mt-1 truncate" title={finding.file_path}>
              <FileCode className="w-3 h-3 inline mr-1" />
              {finding.file_path}:{finding.line_start}
            </p>
          )}
        </div>
      </div>

      <div className="flex items-center gap-2 mt-2">
        {finding.is_verified && (
          <Badge variant="outline" className="text-[10px] px-1 py-0 bg-green-900/30 border-green-700 text-green-400">
            <CheckCircle2 className="w-2.5 h-2.5 mr-0.5" />
            Verified
          </Badge>
        )}
        {finding.has_poc && (
          <Badge variant="outline" className="text-[10px] px-1 py-0 bg-red-900/30 border-red-700 text-red-400">
            <Code className="w-2.5 h-2.5 mr-0.5" />
            PoC
          </Badge>
        )}
      </div>
    </div>
  );
}
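The page above resolves taskId from the URL via useParams, but the router change itself is not part of this diff. A minimal sketch of the wiring, assuming react-router v6 and a hypothetical page path and route (both assumptions, not taken from this commit):

// Hypothetical route registration; the path and import location are assumptions.
import { Route } from "react-router-dom";
import AgentAuditPage from "@/pages/AgentAuditPage";

// somewhere inside the application's <Routes> tree:
<Route path="/agent-tasks/:taskId" element={<AgentAuditPage />} />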
@@ -0,0 +1,306 @@
/**
 * Agent Tasks API
 * API calls for Agent audit tasks
 */

import { apiClient } from "./serverClient";

// ============ Types ============

export interface AgentTask {
  id: string;
  project_id: string;
  name: string | null;
  description: string | null;
  task_type: string;
  status: string;
  current_phase: string | null;
  current_step: string | null;

  // Statistics
  total_files: number;
  indexed_files: number;
  analyzed_files: number;
  total_chunks: number;
  findings_count: number;
  verified_count: number;
  false_positive_count: number;

  // Severity statistics
  critical_count: number;
  high_count: number;
  medium_count: number;
  low_count: number;

  // Scores
  quality_score: number;
  security_score: number;

  // Timestamps
  created_at: string;
  started_at: string | null;
  completed_at: string | null;

  // Progress
  progress_percentage: number;
}

export interface AgentFinding {
  id: string;
  task_id: string;
  vulnerability_type: string;
  severity: string;
  title: string;
  description: string | null;

  file_path: string | null;
  line_start: number | null;
  line_end: number | null;
  code_snippet: string | null;

  status: string;
  is_verified: boolean;
  has_poc: boolean;
  poc_code: string | null;

  suggestion: string | null;
  fix_code: string | null;
  ai_explanation: string | null;
  ai_confidence: number | null;

  created_at: string;
}

export interface AgentEvent {
  id: string;
  task_id: string;
  event_type: string;
  phase: string | null;
  message: string | null;
  tool_name: string | null;
  tool_input?: Record<string, unknown>;
  tool_output?: Record<string, unknown>;
  tool_duration_ms: number | null;
  finding_id: string | null;
  tokens_used?: number;
  metadata?: Record<string, unknown>;
  sequence: number;
  timestamp: string;
}

export interface CreateAgentTaskRequest {
  project_id: string;
  name?: string;
  description?: string;
  audit_scope?: Record<string, unknown>;
  target_vulnerabilities?: string[];
  verification_level?: "analysis_only" | "sandbox" | "generate_poc";
  branch_name?: string;
  exclude_patterns?: string[];
  target_files?: string[];
  max_iterations?: number;
  token_budget?: number;
  timeout_seconds?: number;
}

export interface AgentTaskSummary {
  task_id: string;
  status: string;
  progress_percentage: number;
  security_score: number;
  quality_score: number;
  statistics: {
    total_files: number;
    indexed_files: number;
    analyzed_files: number;
    total_chunks: number;
    findings_count: number;
    verified_count: number;
    false_positive_count: number;
  };
  severity_distribution: {
    critical: number;
    high: number;
    medium: number;
    low: number;
  };
  vulnerability_types: Record<string, { total: number; verified: number }>;
  duration_seconds: number | null;
}
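For orientation, each SSE "data:" line emitted by the backend deserializes into an AgentEvent as declared above. The payload below is purely illustrative (field values are invented, only the field names come from the interface):

// Example SSE line, illustrative values only:
// data: {"id":"evt_0042","task_id":"a1b2c3","event_type":"tool_call","phase":"analysis","message":"Searching for raw SQL usage","tool_name":"semantic_search","tool_duration_ms":412,"finding_id":null,"sequence":42,"timestamp":"2024-01-15T10:03:21Z"}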
// ============ API Functions ============

/**
 * Create an Agent audit task
 */
export async function createAgentTask(data: CreateAgentTaskRequest): Promise<AgentTask> {
  const response = await apiClient.post("/agent-tasks/", data);
  return response.data;
}

/**
 * List Agent tasks
 */
export async function getAgentTasks(params?: {
  project_id?: string;
  status?: string;
  skip?: number;
  limit?: number;
}): Promise<AgentTask[]> {
  const response = await apiClient.get("/agent-tasks/", { params });
  return response.data;
}

/**
 * Get Agent task details
 */
export async function getAgentTask(taskId: string): Promise<AgentTask> {
  const response = await apiClient.get(`/agent-tasks/${taskId}`);
  return response.data;
}

/**
 * Start an Agent task
 */
export async function startAgentTask(taskId: string): Promise<{ message: string; task_id: string }> {
  const response = await apiClient.post(`/agent-tasks/${taskId}/start`);
  return response.data;
}

/**
 * Cancel an Agent task
 */
export async function cancelAgentTask(taskId: string): Promise<{ message: string; task_id: string }> {
  const response = await apiClient.post(`/agent-tasks/${taskId}/cancel`);
  return response.data;
}

/**
 * List events for an Agent task
 */
export async function getAgentEvents(
  taskId: string,
  params?: { after_sequence?: number; limit?: number }
): Promise<AgentEvent[]> {
  const response = await apiClient.get(`/agent-tasks/${taskId}/events/list`, { params });
  return response.data;
}

/**
 * List findings for an Agent task
 */
export async function getAgentFindings(
  taskId: string,
  params?: {
    severity?: string;
    vulnerability_type?: string;
    is_verified?: boolean;
  }
): Promise<AgentFinding[]> {
  const response = await apiClient.get(`/agent-tasks/${taskId}/findings`, { params });
  return response.data;
}

/**
 * Get a single finding
 */
export async function getAgentFinding(taskId: string, findingId: string): Promise<AgentFinding> {
  const response = await apiClient.get(`/agent-tasks/${taskId}/findings/${findingId}`);
  return response.data;
}

/**
 * Update a finding's status
 */
export async function updateAgentFinding(
  taskId: string,
  findingId: string,
  data: { status?: string }
): Promise<AgentFinding> {
  const response = await apiClient.patch(`/agent-tasks/${taskId}/findings/${findingId}`, data);
  return response.data;
}

/**
 * Get a task summary
 */
export async function getAgentTaskSummary(taskId: string): Promise<AgentTaskSummary> {
  const response = await apiClient.get(`/agent-tasks/${taskId}/summary`);
  return response.data;
}

/**
 * Create an SSE EventSource
 */
export function createAgentEventSource(taskId: string, afterSequence = 0): EventSource {
  const baseUrl = import.meta.env.VITE_API_URL || "";
  const url = `${baseUrl}/api/v1/agent-tasks/${taskId}/events?after_sequence=${afterSequence}`;

  // Note: EventSource does not support custom headers, so authentication has to be
  // passed via URL parameters or cookies. If header-based auth is needed,
  // use fetch + ReadableStream instead (see streamAgentEvents below).
  return new EventSource(url, { withCredentials: true });
}

/**
 * Stream events via fetch (supports custom headers)
 */
export async function* streamAgentEvents(
  taskId: string,
  afterSequence = 0,
  signal?: AbortSignal
): AsyncGenerator<AgentEvent, void, unknown> {
  const token = localStorage.getItem("auth_token");
  const baseUrl = import.meta.env.VITE_API_URL || "";
  const url = `${baseUrl}/api/v1/agent-tasks/${taskId}/events?after_sequence=${afterSequence}`;

  const response = await fetch(url, {
    headers: {
      Authorization: `Bearer ${token}`,
      Accept: "text/event-stream",
    },
    signal,
  });

  if (!response.ok) {
    throw new Error(`Failed to connect to event stream: ${response.statusText}`);
  }

  const reader = response.body?.getReader();
  if (!reader) {
    throw new Error("No response body");
  }

  const decoder = new TextDecoder();
  let buffer = "";

  try {
    while (true) {
      const { done, value } = await reader.read();

      if (done) {
        break;
      }

      buffer += decoder.decode(value, { stream: true });

      // Parse SSE format
      const lines = buffer.split("\n");
      buffer = lines.pop() || "";

      for (const line of lines) {
        if (line.startsWith("data: ")) {
          const data = line.slice(6);
          try {
            const event = JSON.parse(data) as AgentEvent;
            yield event;
          } catch {
            // Ignore parse errors
          }
        }
      }
    }
  } finally {
    reader.releaseLock();
  }
}
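End to end, the intended client flow is create, then start, then stream. A minimal sketch using only the functions exported above (the project id, option values, and logging are placeholders, not part of this commit):

// Hypothetical driver illustrating the call order; adjust options to the real project.
import { createAgentTask, startAgentTask, streamAgentEvents } from "@/shared/api/agentTasks";

async function runAgentAudit(projectId: string): Promise<void> {
  // 1. Create the task with a sandbox verification level.
  const task = await createAgentTask({ project_id: projectId, verification_level: "sandbox" });

  // 2. Kick off execution on the backend.
  await startAgentTask(task.id);

  // 3. Consume the event stream until a terminal event arrives.
  const controller = new AbortController();
  try {
    for await (const event of streamAgentEvents(task.id, 0, controller.signal)) {
      console.log(`[${event.event_type}]`, event.message);
      if (["task_complete", "task_error", "task_cancel"].includes(event.event_type)) break;
    }
  } finally {
    controller.abort();
  }
}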