feat(agent): implement Agent audit module with LangGraph integration

- Introduce new Agent audit functionality for autonomous code security analysis and vulnerability verification.
- Add API endpoints for managing Agent tasks and configurations.
- Implement UI components for Agent mode selection and embedding model configuration.
- Enhance the overall architecture with a focus on RAG (Retrieval-Augmented Generation) for improved code semantic search.
- Create a sandbox environment for secure execution of vulnerability tests.
- Update documentation to include details on the new Agent audit features and usage instructions.
This commit is contained in:
lintsinghua 2025-12-11 19:09:10 +08:00
parent 7c9b9ea933
commit 9bc114af1f
68 changed files with 14072 additions and 5 deletions

View File

@ -58,3 +58,4 @@ CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]

View File

@ -102,3 +102,4 @@ format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

View File

@ -89,3 +89,4 @@ else:
asyncio.run(run_migrations_online())

View File

@ -24,3 +24,4 @@ def downgrade() -> None:
${downgrades if downgrades else "pass"}

View File

@ -0,0 +1,255 @@
"""Add agent audit tables
Revision ID: 006_add_agent_tables
Revises: 5fc1cc05d5d0
Create Date: 2024-01-15 10:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '006_add_agent_tables'
down_revision = '5fc1cc05d5d0'
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Create the agent_tasks, agent_events and agent_findings tables and their indexes.

    NOTE(review): `default=` and `onupdate=` on sa.Column are Python-side
    (ORM) defaults and produce no DDL in `op.create_table`; only
    `server_default=` creates a database-level default. Confirm whether
    DB-level defaults were intended for the counter columns below.
    """
    # --- agent_tasks: one row per agent audit run ---
    op.create_table(
        'agent_tasks',
        sa.Column('id', sa.String(36), primary_key=True),
        sa.Column('project_id', sa.String(36), sa.ForeignKey('projects.id', ondelete='CASCADE'), nullable=False),
        # Basic task info
        sa.Column('name', sa.String(255), nullable=True),
        sa.Column('description', sa.Text(), nullable=True),
        sa.Column('task_type', sa.String(50), default='agent_audit'),
        # Task configuration
        sa.Column('audit_scope', sa.JSON(), nullable=True),
        sa.Column('target_vulnerabilities', sa.JSON(), nullable=True),
        sa.Column('verification_level', sa.String(50), default='sandbox'),
        # Branch under audit
        sa.Column('branch_name', sa.String(255), nullable=True),
        # Patterns excluded from scanning
        sa.Column('exclude_patterns', sa.JSON(), nullable=True),
        # Explicit file whitelist
        sa.Column('target_files', sa.JSON(), nullable=True),
        # LLM configuration
        sa.Column('llm_config', sa.JSON(), nullable=True),
        # Agent configuration and budgets
        sa.Column('agent_config', sa.JSON(), nullable=True),
        sa.Column('max_iterations', sa.Integer(), default=50),
        sa.Column('token_budget', sa.Integer(), default=100000),
        sa.Column('timeout_seconds', sa.Integer(), default=1800),
        # Lifecycle status
        sa.Column('status', sa.String(20), default='pending'),
        sa.Column('current_phase', sa.String(50), nullable=True),
        sa.Column('current_step', sa.String(255), nullable=True),
        sa.Column('error_message', sa.Text(), nullable=True),
        # Progress counters
        sa.Column('total_files', sa.Integer(), default=0),
        sa.Column('indexed_files', sa.Integer(), default=0),
        sa.Column('analyzed_files', sa.Integer(), default=0),
        sa.Column('total_chunks', sa.Integer(), default=0),
        # Agent runtime statistics
        sa.Column('total_iterations', sa.Integer(), default=0),
        sa.Column('tool_calls_count', sa.Integer(), default=0),
        sa.Column('tokens_used', sa.Integer(), default=0),
        # Finding counters
        sa.Column('findings_count', sa.Integer(), default=0),
        sa.Column('verified_count', sa.Integer(), default=0),
        sa.Column('false_positive_count', sa.Integer(), default=0),
        # Per-severity counters
        sa.Column('critical_count', sa.Integer(), default=0),
        sa.Column('high_count', sa.Integer(), default=0),
        sa.Column('medium_count', sa.Integer(), default=0),
        sa.Column('low_count', sa.Integer(), default=0),
        # Quality scores
        sa.Column('quality_score', sa.Float(), default=0.0),
        sa.Column('security_score', sa.Float(), default=0.0),
        # Audit plan produced by the agent
        sa.Column('audit_plan', sa.JSON(), nullable=True),
        # Timestamps (note: `onupdate` is ORM-side only, see docstring)
        sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now()),
        sa.Column('updated_at', sa.DateTime(timezone=True), onupdate=sa.func.now()),
        sa.Column('started_at', sa.DateTime(timezone=True), nullable=True),
        sa.Column('completed_at', sa.DateTime(timezone=True), nullable=True),
        # Creator
        sa.Column('created_by', sa.String(36), sa.ForeignKey('users.id'), nullable=False),
    )
    # agent_tasks indexes
    op.create_index('ix_agent_tasks_project_id', 'agent_tasks', ['project_id'])
    op.create_index('ix_agent_tasks_status', 'agent_tasks', ['status'])
    op.create_index('ix_agent_tasks_created_by', 'agent_tasks', ['created_by'])
    op.create_index('ix_agent_tasks_created_at', 'agent_tasks', ['created_at'])
    # --- agent_events: append-only event log per task ---
    op.create_table(
        'agent_events',
        sa.Column('id', sa.String(36), primary_key=True),
        sa.Column('task_id', sa.String(36), sa.ForeignKey('agent_tasks.id', ondelete='CASCADE'), nullable=False),
        # Event classification
        sa.Column('event_type', sa.String(50), nullable=False),
        sa.Column('phase', sa.String(50), nullable=True),
        # Human-readable event text
        sa.Column('message', sa.Text(), nullable=True),
        # Tool-call details
        sa.Column('tool_name', sa.String(100), nullable=True),
        sa.Column('tool_input', sa.JSON(), nullable=True),
        sa.Column('tool_output', sa.JSON(), nullable=True),
        sa.Column('tool_duration_ms', sa.Integer(), nullable=True),
        # Associated finding (no FK — confirm whether one was intended)
        sa.Column('finding_id', sa.String(36), nullable=True),
        # Token consumption
        sa.Column('tokens_used', sa.Integer(), default=0),
        # Free-form metadata
        # NOTE(review): `metadata` is a reserved attribute name on SQLAlchemy
        # declarative models — the ORM model must map this column under a
        # different attribute name. Confirm.
        sa.Column('metadata', sa.JSON(), nullable=True),
        # Monotonic ordering key used by the SSE stream
        sa.Column('sequence', sa.Integer(), default=0),
        # Timestamp
        sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now()),
    )
    # agent_events indexes
    op.create_index('ix_agent_events_task_id', 'agent_events', ['task_id'])
    op.create_index('ix_agent_events_event_type', 'agent_events', ['event_type'])
    op.create_index('ix_agent_events_sequence', 'agent_events', ['sequence'])
    op.create_index('ix_agent_events_created_at', 'agent_events', ['created_at'])
    # --- agent_findings: vulnerabilities discovered by a task ---
    op.create_table(
        'agent_findings',
        sa.Column('id', sa.String(36), primary_key=True),
        sa.Column('task_id', sa.String(36), sa.ForeignKey('agent_tasks.id', ondelete='CASCADE'), nullable=False),
        # Vulnerability basics
        sa.Column('vulnerability_type', sa.String(100), nullable=False),
        sa.Column('severity', sa.String(20), nullable=False),
        sa.Column('title', sa.String(500), nullable=False),
        sa.Column('description', sa.Text(), nullable=True),
        # Location
        sa.Column('file_path', sa.String(500), nullable=True),
        sa.Column('line_start', sa.Integer(), nullable=True),
        sa.Column('line_end', sa.Integer(), nullable=True),
        sa.Column('column_start', sa.Integer(), nullable=True),
        sa.Column('column_end', sa.Integer(), nullable=True),
        sa.Column('function_name', sa.String(255), nullable=True),
        sa.Column('class_name', sa.String(255), nullable=True),
        # Code excerpts
        sa.Column('code_snippet', sa.Text(), nullable=True),
        sa.Column('code_context', sa.Text(), nullable=True),
        # Data-flow information
        sa.Column('source', sa.Text(), nullable=True),
        sa.Column('sink', sa.Text(), nullable=True),
        sa.Column('dataflow_path', sa.JSON(), nullable=True),
        # Verification state
        sa.Column('status', sa.String(30), default='new'),
        sa.Column('is_verified', sa.Boolean(), default=False),
        sa.Column('verification_method', sa.Text(), nullable=True),
        sa.Column('verification_result', sa.JSON(), nullable=True),
        sa.Column('verified_at', sa.DateTime(timezone=True), nullable=True),
        # Proof of concept
        sa.Column('has_poc', sa.Boolean(), default=False),
        sa.Column('poc_code', sa.Text(), nullable=True),
        sa.Column('poc_description', sa.Text(), nullable=True),
        sa.Column('poc_steps', sa.JSON(), nullable=True),
        # Remediation
        sa.Column('suggestion', sa.Text(), nullable=True),
        sa.Column('fix_code', sa.Text(), nullable=True),
        sa.Column('fix_description', sa.Text(), nullable=True),
        sa.Column('references', sa.JSON(), nullable=True),
        # AI explanation
        sa.Column('ai_explanation', sa.Text(), nullable=True),
        sa.Column('ai_confidence', sa.Float(), nullable=True),
        # Explainable-AI fields (what / why / how / impact)
        sa.Column('xai_what', sa.Text(), nullable=True),
        sa.Column('xai_why', sa.Text(), nullable=True),
        sa.Column('xai_how', sa.Text(), nullable=True),
        sa.Column('xai_impact', sa.Text(), nullable=True),
        # Matched rule / pattern
        sa.Column('matched_rule_code', sa.String(100), nullable=True),
        sa.Column('matched_pattern', sa.Text(), nullable=True),
        # CVSS
        sa.Column('cvss_score', sa.Float(), nullable=True),
        sa.Column('cvss_vector', sa.String(100), nullable=True),
        # Free-form metadata (same reserved-name caveat as agent_events)
        sa.Column('metadata', sa.JSON(), nullable=True),
        sa.Column('tags', sa.JSON(), nullable=True),
        # Deduplication fingerprint
        sa.Column('fingerprint', sa.String(64), nullable=True),
        # Timestamps
        sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now()),
        sa.Column('updated_at', sa.DateTime(timezone=True), onupdate=sa.func.now()),
    )
    # agent_findings indexes
    op.create_index('ix_agent_findings_task_id', 'agent_findings', ['task_id'])
    op.create_index('ix_agent_findings_vulnerability_type', 'agent_findings', ['vulnerability_type'])
    op.create_index('ix_agent_findings_severity', 'agent_findings', ['severity'])
    op.create_index('ix_agent_findings_file_path', 'agent_findings', ['file_path'])
    op.create_index('ix_agent_findings_status', 'agent_findings', ['status'])
    op.create_index('ix_agent_findings_fingerprint', 'agent_findings', ['fingerprint'])
def downgrade() -> None:
    """Drop the agent audit tables and indexes (exact reverse of upgrade)."""
    # Child tables first so FK constraints never block the drops; within a
    # table, indexes are removed before the table itself.
    tables_and_indexes = (
        (
            'agent_findings',
            (
                'ix_agent_findings_fingerprint',
                'ix_agent_findings_status',
                'ix_agent_findings_file_path',
                'ix_agent_findings_severity',
                'ix_agent_findings_vulnerability_type',
                'ix_agent_findings_task_id',
            ),
        ),
        (
            'agent_events',
            (
                'ix_agent_events_created_at',
                'ix_agent_events_sequence',
                'ix_agent_events_event_type',
                'ix_agent_events_task_id',
            ),
        ),
        (
            'agent_tasks',
            (
                'ix_agent_tasks_created_at',
                'ix_agent_tasks_created_by',
                'ix_agent_tasks_status',
                'ix_agent_tasks_project_id',
            ),
        ),
    )
    for table_name, index_names in tables_and_indexes:
        for index_name in index_names:
            op.drop_index(index_name, table_name)
        op.drop_table(table_name)

View File

@ -1,5 +1,5 @@
from fastapi import APIRouter
from app.api.v1.endpoints import auth, users, projects, tasks, scan, members, config, database, prompts, rules, agent_tasks, embedding_config
api_router = APIRouter()
api_router.include_router(auth.router, prefix="/auth", tags=["auth"])
@ -12,3 +12,5 @@ api_router.include_router(config.router, prefix="/config", tags=["config"])
api_router.include_router(database.router, prefix="/database", tags=["database"])
api_router.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
api_router.include_router(rules.router, prefix="/rules", tags=["rules"])
api_router.include_router(agent_tasks.router, prefix="/agent-tasks", tags=["agent-tasks"])
api_router.include_router(embedding_config.router, prefix="/embedding", tags=["embedding"])

View File

@ -0,0 +1,639 @@
"""
DeepAudit Agent 审计任务 API
基于 LangGraph Agent 审计
"""
import asyncio
import json
import logging
import os
import shutil
from typing import Any, List, Optional, Dict
from datetime import datetime, timezone
from uuid import uuid4
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks, Query
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from pydantic import BaseModel, Field
from app.api import deps
from app.db.session import get_db, async_session_factory
from app.models.agent_task import (
AgentTask, AgentEvent, AgentFinding,
AgentTaskStatus, AgentTaskPhase, AgentEventType,
VulnerabilitySeverity, FindingStatus,
)
from app.models.project import Project
from app.models.user import User
from app.services.agent import AgentRunner, EventManager, run_agent_task
logger = logging.getLogger(__name__)
router = APIRouter()
# In-process registry of currently running agent runners, keyed by task id.
# NOTE(review): this is per-worker state — with multiple uvicorn workers, a
# cancel request may land on a worker that does not hold the runner. Confirm
# the deployment runs a single worker or add shared cancellation signalling.
_running_tasks: Dict[str, AgentRunner] = {}
# ============ Schemas ============
class AgentTaskCreate(BaseModel):
    """Request body for creating an agent audit task."""
    project_id: str = Field(..., description="项目 ID")
    name: Optional[str] = Field(None, description="任务名称")
    description: Optional[str] = Field(None, description="任务描述")
    # Audit configuration
    audit_scope: Optional[dict] = Field(None, description="审计范围")
    target_vulnerabilities: Optional[List[str]] = Field(
        default=["sql_injection", "xss", "command_injection", "path_traversal", "ssrf"],
        description="目标漏洞类型"
    )
    verification_level: str = Field(
        "sandbox",
        description="验证级别: analysis_only, sandbox, generate_poc"
    )
    # Git branch to audit (optional)
    branch_name: Optional[str] = Field(None, description="分支名称")
    # Path patterns excluded from scanning
    exclude_patterns: Optional[List[str]] = Field(
        default=["node_modules", "__pycache__", ".git", "*.min.js"],
        description="排除模式"
    )
    # Explicit file whitelist; when None the whole project is scanned
    target_files: Optional[List[str]] = Field(None, description="指定扫描的文件")
    # Agent limits (Pydantic validates the ge/le bounds)
    max_iterations: int = Field(3, ge=1, le=10, description="最大分析迭代次数")
    timeout_seconds: int = Field(1800, ge=60, le=7200, description="超时时间(秒)")
class AgentTaskResponse(BaseModel):
    """Serialized view of an agent task returned by the API."""
    id: str
    project_id: str
    name: Optional[str]
    description: Optional[str]
    status: str
    current_phase: Optional[str]
    # Aggregate statistics
    total_findings: int = 0
    verified_findings: int = 0
    security_score: Optional[int] = None
    # Timestamps
    created_at: datetime
    started_at: Optional[datetime] = None
    # NOTE(review): the 006 migration names this column `completed_at` —
    # confirm the ORM model actually exposes `finished_at`.
    finished_at: Optional[datetime] = None
    # Raw task configuration blob
    config: Optional[dict] = None
    # Populated when the task failed
    error_message: Optional[str] = None
    class Config:
        from_attributes = True
class AgentEventResponse(BaseModel):
    """Single agent event (progress update, tool call, finding, ...)."""
    id: str
    task_id: str
    event_type: str
    phase: Optional[str]
    # NOTE(review): agent_events.message is nullable in the migration but
    # required here — confirm every event row carries a message.
    message: str
    sequence: int
    created_at: datetime
    # Optional tool / progress details
    tool_name: Optional[str] = None
    tool_duration_ms: Optional[int] = None
    progress_percent: Optional[float] = None
    finding_id: Optional[str] = None
    class Config:
        from_attributes = True
class AgentFindingResponse(BaseModel):
    """Serialized vulnerability finding."""
    id: str
    task_id: str
    vulnerability_type: str
    severity: str
    title: str
    description: Optional[str]
    file_path: Optional[str]
    line_start: Optional[int]
    line_end: Optional[int]
    code_snippet: Optional[str]
    is_verified: bool
    # NOTE(review): the migration stores `ai_confidence` — confirm the ORM
    # model maps it to an attribute named `confidence`.
    confidence: float
    status: str
    suggestion: Optional[str] = None
    poc: Optional[dict] = None
    created_at: datetime
    class Config:
        from_attributes = True
class TaskSummaryResponse(BaseModel):
    """Aggregate statistics for one agent task."""
    task_id: str
    status: str
    # NOTE(review): declared int here but stored as Float by the migration —
    # confirm the intended type.
    security_score: Optional[int]
    total_findings: int
    verified_findings: int
    # severity value -> count
    severity_distribution: Dict[str, int]
    # vulnerability type -> count
    vulnerability_types: Dict[str, int]
    duration_seconds: Optional[int]
    phases_completed: List[str]
# ============ 后台任务执行 ============
async def _execute_agent_task(task_id: str, project_root: str):
    """
    Run an agent audit task in the background.

    Opens its own session (BackgroundTasks outlive the request-scoped one),
    registers the runner in `_running_tasks` so it can be cancelled, and on
    failure records FAILED status plus the error message on the task row.
    """
    async with async_session_factory() as db:
        try:
            task = await db.get(AgentTask, task_id, options=[selectinload(AgentTask.project)])
            if not task:
                logger.error(f"Task {task_id} not found")
                return
            runner = AgentRunner(db, task, project_root)
            _running_tasks[task_id] = runner
            result = await runner.run()
            logger.info(f"Task {task_id} completed: {result.get('success', False)}")
        except Exception as e:
            logger.error(f"Task {task_id} failed: {e}", exc_info=True)
            # Fix: roll back first — if the failure came from the database the
            # session holds a failed transaction, and the status update below
            # would itself raise, leaving the task stuck in "running".
            await db.rollback()
            task = await db.get(AgentTask, task_id)
            if task:
                task.status = AgentTaskStatus.FAILED
                task.error_message = str(e)
                # NOTE(review): migration column is `completed_at`; confirm
                # the model exposes `finished_at`.
                task.finished_at = datetime.now(timezone.utc)
            await db.commit()
        finally:
            # Always deregister so the task-id cannot leak a stale runner.
            _running_tasks.pop(task_id, None)
# ============ API Endpoints ============
@router.post("/", response_model=AgentTaskResponse)
async def create_agent_task(
request: AgentTaskCreate,
background_tasks: BackgroundTasks,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(deps.get_current_user),
) -> Any:
"""
创建并启动 Agent 审计任务
"""
# 验证项目
project = await db.get(Project, request.project_id)
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
if project.owner_id != current_user.id:
raise HTTPException(status_code=403, detail="无权访问此项目")
# 创建任务
task = AgentTask(
id=str(uuid4()),
project_id=project.id,
name=request.name or f"Agent Audit - {datetime.now().strftime('%Y%m%d_%H%M%S')}",
description=request.description,
status=AgentTaskStatus.PENDING,
current_phase=AgentTaskPhase.PLANNING,
config={
"target_vulnerabilities": request.target_vulnerabilities,
"verification_level": request.verification_level,
"exclude_patterns": request.exclude_patterns,
"target_files": request.target_files,
"max_iterations": request.max_iterations,
"timeout_seconds": request.timeout_seconds,
},
created_by=current_user.id,
)
db.add(task)
await db.commit()
await db.refresh(task)
# 确定项目根目录
project_root = _get_project_root(project, task.id)
# 在后台启动任务
background_tasks.add_task(_execute_agent_task, task.id, project_root)
logger.info(f"Created agent task {task.id} for project {project.name}")
return task
@router.get("/", response_model=List[AgentTaskResponse])
async def list_agent_tasks(
project_id: Optional[str] = None,
status: Optional[str] = None,
skip: int = Query(0, ge=0),
limit: int = Query(20, ge=1, le=100),
db: AsyncSession = Depends(get_db),
current_user: User = Depends(deps.get_current_user),
) -> Any:
"""
获取 Agent 任务列表
"""
# 获取用户的项目
projects_result = await db.execute(
select(Project.id).where(Project.owner_id == current_user.id)
)
user_project_ids = [p[0] for p in projects_result.fetchall()]
if not user_project_ids:
return []
# 构建查询
query = select(AgentTask).where(AgentTask.project_id.in_(user_project_ids))
if project_id:
query = query.where(AgentTask.project_id == project_id)
if status:
try:
status_enum = AgentTaskStatus(status)
query = query.where(AgentTask.status == status_enum)
except ValueError:
pass
query = query.order_by(AgentTask.created_at.desc())
query = query.offset(skip).limit(limit)
result = await db.execute(query)
tasks = result.scalars().all()
return tasks
@router.get("/{task_id}", response_model=AgentTaskResponse)
async def get_agent_task(
task_id: str,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(deps.get_current_user),
) -> Any:
"""
获取 Agent 任务详情
"""
task = await db.get(AgentTask, task_id)
if not task:
raise HTTPException(status_code=404, detail="任务不存在")
# 检查权限
project = await db.get(Project, task.project_id)
if not project or project.owner_id != current_user.id:
raise HTTPException(status_code=403, detail="无权访问此任务")
return task
@router.post("/{task_id}/cancel")
async def cancel_agent_task(
task_id: str,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(deps.get_current_user),
) -> Any:
"""
取消 Agent 任务
"""
task = await db.get(AgentTask, task_id)
if not task:
raise HTTPException(status_code=404, detail="任务不存在")
project = await db.get(Project, task.project_id)
if not project or project.owner_id != current_user.id:
raise HTTPException(status_code=403, detail="无权操作此任务")
if task.status in [AgentTaskStatus.COMPLETED, AgentTaskStatus.FAILED, AgentTaskStatus.CANCELLED]:
raise HTTPException(status_code=400, detail="任务已结束,无法取消")
# 取消运行中的任务
runner = _running_tasks.get(task_id)
if runner:
runner.cancel()
# 更新状态
task.status = AgentTaskStatus.CANCELLED
task.finished_at = datetime.now(timezone.utc)
await db.commit()
return {"message": "任务已取消", "task_id": task_id}
@router.get("/{task_id}/events")
async def stream_agent_events(
task_id: str,
after_sequence: int = Query(0, ge=0, description="从哪个序号之后开始"),
db: AsyncSession = Depends(get_db),
current_user: User = Depends(deps.get_current_user),
):
"""
获取 Agent 事件流 (SSE)
"""
task = await db.get(AgentTask, task_id)
if not task:
raise HTTPException(status_code=404, detail="任务不存在")
project = await db.get(Project, task.project_id)
if not project or project.owner_id != current_user.id:
raise HTTPException(status_code=403, detail="无权访问此任务")
async def event_generator():
"""生成 SSE 事件流"""
last_sequence = after_sequence
poll_interval = 0.5
max_idle = 300 # 5 分钟无事件后关闭
idle_time = 0
while True:
# 查询新事件
async with async_session_factory() as session:
result = await session.execute(
select(AgentEvent)
.where(AgentEvent.task_id == task_id)
.where(AgentEvent.sequence > last_sequence)
.order_by(AgentEvent.sequence)
.limit(50)
)
events = result.scalars().all()
# 获取任务状态
current_task = await session.get(AgentTask, task_id)
task_status = current_task.status if current_task else None
if events:
idle_time = 0
for event in events:
last_sequence = event.sequence
data = {
"id": event.id,
"type": event.event_type.value if hasattr(event.event_type, 'value') else str(event.event_type),
"phase": event.phase.value if event.phase and hasattr(event.phase, 'value') else event.phase,
"message": event.message,
"sequence": event.sequence,
"timestamp": event.created_at.isoformat() if event.created_at else None,
"progress_percent": event.progress_percent,
"tool_name": event.tool_name,
}
yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
else:
idle_time += poll_interval
# 检查任务是否结束
if task_status in [AgentTaskStatus.COMPLETED, AgentTaskStatus.FAILED, AgentTaskStatus.CANCELLED]:
yield f"data: {json.dumps({'type': 'task_end', 'status': task_status.value})}\n\n"
break
# 检查空闲超时
if idle_time >= max_idle:
yield f"data: {json.dumps({'type': 'timeout'})}\n\n"
break
await asyncio.sleep(poll_interval)
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
}
)
@router.get("/{task_id}/events/list", response_model=List[AgentEventResponse])
async def list_agent_events(
task_id: str,
after_sequence: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=500),
db: AsyncSession = Depends(get_db),
current_user: User = Depends(deps.get_current_user),
) -> Any:
"""
获取 Agent 事件列表
"""
task = await db.get(AgentTask, task_id)
if not task:
raise HTTPException(status_code=404, detail="任务不存在")
project = await db.get(Project, task.project_id)
if not project or project.owner_id != current_user.id:
raise HTTPException(status_code=403, detail="无权访问此任务")
result = await db.execute(
select(AgentEvent)
.where(AgentEvent.task_id == task_id)
.where(AgentEvent.sequence > after_sequence)
.order_by(AgentEvent.sequence)
.limit(limit)
)
events = result.scalars().all()
return events
@router.get("/{task_id}/findings", response_model=List[AgentFindingResponse])
async def list_agent_findings(
task_id: str,
severity: Optional[str] = None,
verified_only: bool = False,
skip: int = Query(0, ge=0),
limit: int = Query(50, ge=1, le=200),
db: AsyncSession = Depends(get_db),
current_user: User = Depends(deps.get_current_user),
) -> Any:
"""
获取 Agent 发现列表
"""
task = await db.get(AgentTask, task_id)
if not task:
raise HTTPException(status_code=404, detail="任务不存在")
project = await db.get(Project, task.project_id)
if not project or project.owner_id != current_user.id:
raise HTTPException(status_code=403, detail="无权访问此任务")
query = select(AgentFinding).where(AgentFinding.task_id == task_id)
if severity:
try:
sev_enum = VulnerabilitySeverity(severity)
query = query.where(AgentFinding.severity == sev_enum)
except ValueError:
pass
if verified_only:
query = query.where(AgentFinding.is_verified == True)
# 按严重程度排序
severity_order = {
VulnerabilitySeverity.CRITICAL: 0,
VulnerabilitySeverity.HIGH: 1,
VulnerabilitySeverity.MEDIUM: 2,
VulnerabilitySeverity.LOW: 3,
VulnerabilitySeverity.INFO: 4,
}
query = query.order_by(AgentFinding.severity, AgentFinding.created_at.desc())
query = query.offset(skip).limit(limit)
result = await db.execute(query)
findings = result.scalars().all()
return findings
@router.get("/{task_id}/summary", response_model=TaskSummaryResponse)
async def get_task_summary(
task_id: str,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(deps.get_current_user),
) -> Any:
"""
获取任务摘要
"""
task = await db.get(AgentTask, task_id)
if not task:
raise HTTPException(status_code=404, detail="任务不存在")
project = await db.get(Project, task.project_id)
if not project or project.owner_id != current_user.id:
raise HTTPException(status_code=403, detail="无权访问此任务")
# 获取所有发现
result = await db.execute(
select(AgentFinding).where(AgentFinding.task_id == task_id)
)
findings = result.scalars().all()
# 统计
severity_distribution = {}
vulnerability_types = {}
verified_count = 0
for f in findings:
sev = f.severity.value if hasattr(f.severity, 'value') else str(f.severity)
vtype = f.vulnerability_type.value if hasattr(f.vulnerability_type, 'value') else str(f.vulnerability_type)
severity_distribution[sev] = severity_distribution.get(sev, 0) + 1
vulnerability_types[vtype] = vulnerability_types.get(vtype, 0) + 1
if f.is_verified:
verified_count += 1
# 计算持续时间
duration = None
if task.started_at and task.finished_at:
duration = int((task.finished_at - task.started_at).total_seconds())
# 获取已完成的阶段
phases_result = await db.execute(
select(AgentEvent.phase)
.where(AgentEvent.task_id == task_id)
.where(AgentEvent.event_type == AgentEventType.PHASE_COMPLETE)
.distinct()
)
phases = [p[0].value if p[0] and hasattr(p[0], 'value') else str(p[0]) for p in phases_result.fetchall() if p[0]]
return TaskSummaryResponse(
task_id=task_id,
status=task.status.value if hasattr(task.status, 'value') else str(task.status),
security_score=task.security_score,
total_findings=len(findings),
verified_findings=verified_count,
severity_distribution=severity_distribution,
vulnerability_types=vulnerability_types,
duration_seconds=duration,
phases_completed=phases,
)
@router.patch("/{task_id}/findings/{finding_id}/status")
async def update_finding_status(
task_id: str,
finding_id: str,
status: str,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(deps.get_current_user),
) -> Any:
"""
更新发现状态
"""
task = await db.get(AgentTask, task_id)
if not task:
raise HTTPException(status_code=404, detail="任务不存在")
project = await db.get(Project, task.project_id)
if not project or project.owner_id != current_user.id:
raise HTTPException(status_code=403, detail="无权操作")
finding = await db.get(AgentFinding, finding_id)
if not finding or finding.task_id != task_id:
raise HTTPException(status_code=404, detail="发现不存在")
try:
finding.status = FindingStatus(status)
except ValueError:
raise HTTPException(status_code=400, detail=f"无效的状态: {status}")
await db.commit()
return {"message": "状态已更新", "finding_id": finding_id, "status": status}
# ============ Helper Functions ============
def _get_project_root(project: Project, task_id: str) -> str:
    """
    Resolve (and prepare) the on-disk working directory for a task.

    TODO: the real implementation should
    - unpack ZIP-based projects into a temporary directory
    - clone Git repositories into a temporary directory

    NOTE(review): the path is predictable (/tmp/deepaudit/<task_id>) and
    created with default permissions — consider tempfile.mkdtemp with a
    restrictive mode if the audited code is sensitive.
    """
    workdir = f"/tmp/deepaudit/{task_id}"
    os.makedirs(workdir, exist_ok=True)
    # Mirror the project's stored files into the working directory, if any.
    storage = getattr(project, 'storage_path', None)
    if storage and os.path.exists(storage):
        shutil.copytree(storage, workdir, dirs_exist_ok=True)
    return workdir

View File

@ -0,0 +1,396 @@
"""
嵌入模型配置 API
独立于 LLM 配置专门用于 RAG 系统的嵌入模型
使用 UserConfig.other_config 持久化存储
"""
import json
import uuid
from typing import Any, Optional, List
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api import deps
from app.models.user import User
from app.models.user_config import UserConfig
from app.core.config import settings
router = APIRouter()
# ============ Schemas ============
class EmbeddingProvider(BaseModel):
    """Catalog entry describing one embedding-model provider."""
    id: str
    name: str
    description: str
    # Model names selectable for this provider
    models: List[str]
    requires_api_key: bool
    default_model: str
class EmbeddingConfig(BaseModel):
    """A user's embedding-model configuration (including secrets)."""
    provider: str = Field(description="提供商: openai, ollama, azure, cohere, huggingface")
    model: str = Field(description="模型名称")
    api_key: Optional[str] = Field(default=None, description="API Key (如需要)")
    base_url: Optional[str] = Field(default=None, description="自定义 API 端点")
    dimensions: Optional[int] = Field(default=None, description="向量维度 (某些模型支持)")
    batch_size: int = Field(default=100, description="批处理大小")
class EmbeddingConfigResponse(BaseModel):
    """Embedding configuration as returned to clients.

    The API key is deliberately omitted so it is never echoed back.
    """
    provider: str
    model: str
    base_url: Optional[str]
    dimensions: int
    batch_size: int
    # API key intentionally not included
class TestEmbeddingRequest(BaseModel):
    """Request to test-drive an embedding configuration before saving it."""
    provider: str
    model: str
    api_key: Optional[str] = None
    base_url: Optional[str] = None
    test_text: str = "这是一段测试文本,用于验证嵌入模型是否正常工作。"
class TestEmbeddingResponse(BaseModel):
    """Outcome of an embedding test call."""
    success: bool
    message: str
    dimensions: Optional[int] = None
    sample_embedding: Optional[List[float]] = None  # first 5 dimensions only
    latency_ms: Optional[int] = None
# ============ Provider catalog ============
# Static list served by GET /providers; model lists are advisory — the
# backend does not restrict configs to these names.
EMBEDDING_PROVIDERS: List[EmbeddingProvider] = [
    EmbeddingProvider(
        id="openai",
        name="OpenAI",
        description="OpenAI 官方嵌入模型,高质量、稳定",
        models=[
            "text-embedding-3-small",
            "text-embedding-3-large",
            "text-embedding-ada-002",
        ],
        requires_api_key=True,
        default_model="text-embedding-3-small",
    ),
    EmbeddingProvider(
        id="azure",
        name="Azure OpenAI",
        description="Azure 托管的 OpenAI 嵌入模型",
        models=[
            "text-embedding-3-small",
            "text-embedding-3-large",
            "text-embedding-ada-002",
        ],
        requires_api_key=True,
        default_model="text-embedding-3-small",
    ),
    EmbeddingProvider(
        id="ollama",
        name="Ollama (本地)",
        description="本地运行的开源嵌入模型 (使用 /api/embed 端点)",
        models=[
            "nomic-embed-text",
            "mxbai-embed-large",
            "all-minilm",
            "snowflake-arctic-embed",
            "bge-m3",
            "qwen3-embedding",
        ],
        requires_api_key=False,
        default_model="nomic-embed-text",
    ),
    EmbeddingProvider(
        id="cohere",
        name="Cohere",
        description="Cohere Embed v2 API (api.cohere.com/v2)",
        models=[
            "embed-english-v3.0",
            "embed-multilingual-v3.0",
            "embed-english-light-v3.0",
            "embed-multilingual-light-v3.0",
            "embed-v4.0",
        ],
        requires_api_key=True,
        default_model="embed-multilingual-v3.0",
    ),
    EmbeddingProvider(
        id="huggingface",
        name="HuggingFace",
        description="HuggingFace Inference Providers (router.huggingface.co)",
        models=[
            "sentence-transformers/all-MiniLM-L6-v2",
            "sentence-transformers/all-mpnet-base-v2",
            "BAAI/bge-large-zh-v1.5",
            "BAAI/bge-m3",
        ],
        requires_api_key=True,
        default_model="BAAI/bge-m3",
    ),
    EmbeddingProvider(
        id="jina",
        name="Jina AI",
        description="Jina AI 嵌入模型,代码嵌入效果好",
        models=[
            "jina-embeddings-v2-base-code",
            "jina-embeddings-v2-base-en",
            "jina-embeddings-v2-base-zh",
        ],
        requires_api_key=True,
        default_model="jina-embeddings-v2-base-code",
    ),
]
# ============ Database persistence (async) ============
# Key under which the embedding config lives inside UserConfig.other_config.
EMBEDDING_CONFIG_KEY = "embedding_config"
async def get_embedding_config_from_db(db: AsyncSession, user_id: str) -> EmbeddingConfig:
    """
    Load the user's embedding config from UserConfig.other_config.

    Falls back to the server-wide defaults (which reuse the LLM credentials)
    when the user has no stored config or the stored payload is corrupt.
    """
    row = await db.execute(
        select(UserConfig).where(UserConfig.user_id == user_id)
    )
    user_config = row.scalar_one_or_none()
    if user_config and user_config.other_config:
        try:
            raw = user_config.other_config
            # other_config may arrive as a JSON string or an already-decoded dict.
            parsed = json.loads(raw) if isinstance(raw, str) else raw
            stored = parsed.get(EMBEDDING_CONFIG_KEY)
            if stored:
                return EmbeddingConfig(
                    provider=stored.get("provider", settings.EMBEDDING_PROVIDER),
                    model=stored.get("model", settings.EMBEDDING_MODEL),
                    api_key=stored.get("api_key"),
                    base_url=stored.get("base_url"),
                    dimensions=stored.get("dimensions"),
                    batch_size=stored.get("batch_size", 100),
                )
        except (json.JSONDecodeError, AttributeError):
            # Corrupt or unexpected payloads fall through to the defaults.
            pass
    # Server defaults; embedding reuses the LLM API credentials here.
    return EmbeddingConfig(
        provider=settings.EMBEDDING_PROVIDER,
        model=settings.EMBEDDING_MODEL,
        api_key=settings.LLM_API_KEY,
        base_url=settings.LLM_BASE_URL,
        batch_size=100,
    )
async def save_embedding_config_to_db(db: AsyncSession, user_id: str, config: EmbeddingConfig) -> None:
    """Persist the embedding configuration into UserConfig.other_config (async).

    Only the ``EMBEDDING_CONFIG_KEY`` entry is replaced; any other keys
    already stored in ``other_config`` are preserved.
    """
    result = await db.execute(
        select(UserConfig).where(UserConfig.user_id == user_id)
    )
    user_config = result.scalar_one_or_none()
    # Serialize the embedding config to a plain dict.
    embedding_data = {
        "provider": config.provider,
        "model": config.model,
        "api_key": config.api_key,
        "base_url": config.base_url,
        "dimensions": config.dimensions,
        "batch_size": config.batch_size,
    }
    if user_config:
        # Update the existing row. Bug fix: the reader accepts other_config as
        # either a JSON string or an already-decoded dict; handle the dict case
        # here too (json.loads(dict) raised TypeError and silently discarded
        # every other key stored in other_config).
        raw = user_config.other_config
        if isinstance(raw, dict):
            other_config = dict(raw)
        else:
            try:
                other_config = json.loads(raw) if raw else {}
            except (json.JSONDecodeError, TypeError):
                other_config = {}
        other_config[EMBEDDING_CONFIG_KEY] = embedding_data
        user_config.other_config = json.dumps(other_config)
    else:
        # First save for this user: create the row.
        user_config = UserConfig(
            id=str(uuid.uuid4()),
            user_id=user_id,
            llm_config="{}",
            other_config=json.dumps({EMBEDDING_CONFIG_KEY: embedding_data}),
        )
        db.add(user_config)
    await db.commit()
# ============ API Endpoints ============
@router.get("/providers", response_model=List[EmbeddingProvider])
async def list_embedding_providers(
    current_user: User = Depends(deps.get_current_user),
) -> Any:
    """
    List the available embedding-model providers (static in-code catalogue).
    """
    return EMBEDDING_PROVIDERS
@router.get("/config", response_model=EmbeddingConfigResponse)
async def get_current_config(
    db: AsyncSession = Depends(deps.get_db),
    current_user: User = Depends(deps.get_current_user),
) -> Any:
    """
    Return the current embedding-model configuration (read from the database).

    The API key is intentionally omitted from the response.
    """
    config = await get_embedding_config_from_db(db, current_user.id)
    # Bug fix: honour an explicitly stored dimension; previously the saved
    # config.dimensions field was ignored and the value was always recomputed
    # from the static model table.
    dimensions = config.dimensions if config.dimensions else _get_model_dimensions(config.provider, config.model)
    return EmbeddingConfigResponse(
        provider=config.provider,
        model=config.model,
        base_url=config.base_url,
        dimensions=dimensions,
        batch_size=config.batch_size,
    )
@router.put("/config")
async def update_config(
    config: EmbeddingConfig,
    db: AsyncSession = Depends(deps.get_db),
    current_user: User = Depends(deps.get_current_user),
) -> Any:
    """
    Validate and persist the embedding-model configuration to the database.
    """
    # Resolve the provider; an unknown id is rejected up front.
    provider = next((p for p in EMBEDDING_PROVIDERS if p.id == config.provider), None)
    if provider is None:
        raise HTTPException(status_code=400, detail=f"不支持的提供商: {config.provider}")
    # The model must belong to the selected provider's catalogue.
    if config.model not in provider.models:
        raise HTTPException(status_code=400, detail=f"不支持的模型: {config.model}")
    # Providers that need credentials must be given an API key.
    if provider.requires_api_key and not config.api_key:
        raise HTTPException(status_code=400, detail=f"{config.provider} 需要 API Key")
    # Persist to the database.
    await save_embedding_config_to_db(db, current_user.id, config)
    return {"message": "配置已保存", "provider": config.provider, "model": config.model}
@router.post("/test", response_model=TestEmbeddingResponse)
async def test_embedding(
    request: TestEmbeddingRequest,
    current_user: User = Depends(deps.get_current_user),
) -> Any:
    """
    Run a one-off embedding with the supplied configuration to verify it works.

    Returns success/failure, the vector dimension, the first 5 values of the
    embedding as a sample, and the measured latency in milliseconds.
    """
    import time
    try:
        # Bug fix: use the monotonic perf_counter for latency so a wall-clock
        # adjustment mid-request cannot produce a negative/garbage duration
        # (time.time() is not monotonic).
        start_time = time.perf_counter()
        # Build a throw-away embedding service from the request (no cache).
        from app.services.rag.embeddings import EmbeddingService
        service = EmbeddingService(
            provider=request.provider,
            model=request.model,
            api_key=request.api_key,
            base_url=request.base_url,
            cache_enabled=False,
        )
        # Perform the embedding call.
        embedding = await service.embed(request.test_text)
        latency_ms = int((time.perf_counter() - start_time) * 1000)
        return TestEmbeddingResponse(
            success=True,
            message=f"嵌入成功! 维度: {len(embedding)}",
            dimensions=len(embedding),
            sample_embedding=embedding[:5],  # first 5 dimensions only
            latency_ms=latency_ms,
        )
    except Exception as e:
        # Report the failure in the response body instead of a 500.
        return TestEmbeddingResponse(
            success=False,
            message=f"嵌入失败: {str(e)}",
        )
@router.get("/models/{provider}")
async def get_provider_models(
    provider: str,
    current_user: User = Depends(deps.get_current_user),
) -> Any:
    """
    Return the model catalogue for a single embedding provider.
    """
    for candidate in EMBEDDING_PROVIDERS:
        if candidate.id == provider:
            return {
                "provider": provider,
                "models": candidate.models,
                "default_model": candidate.default_model,
                "requires_api_key": candidate.requires_api_key,
            }
    raise HTTPException(status_code=404, detail=f"提供商不存在: {provider}")
def _get_model_dimensions(provider: str, model: str) -> int:
"""获取模型维度"""
dimensions_map = {
# OpenAI
"text-embedding-3-small": 1536,
"text-embedding-3-large": 3072,
"text-embedding-ada-002": 1536,
# Ollama
"nomic-embed-text": 768,
"mxbai-embed-large": 1024,
"all-minilm": 384,
"snowflake-arctic-embed": 1024,
# Cohere
"embed-english-v3.0": 1024,
"embed-multilingual-v3.0": 1024,
"embed-english-light-v3.0": 384,
"embed-multilingual-light-v3.0": 384,
# HuggingFace
"sentence-transformers/all-MiniLM-L6-v2": 384,
"sentence-transformers/all-mpnet-base-v2": 768,
"BAAI/bge-large-zh-v1.5": 1024,
"BAAI/bge-m3": 1024,
# Jina
"jina-embeddings-v2-base-code": 768,
"jina-embeddings-v2-base-en": 768,
"jina-embeddings-v2-base-zh": 768,
}
return dimensions_map.get(model, 768)

View File

@ -209,3 +209,4 @@ async def remove_project_member(
return {"message": "成员已移除"} return {"message": "成员已移除"}

View File

@ -224,3 +224,4 @@ async def toggle_user_status(
return user return user

View File

@ -76,6 +76,32 @@ class Settings(BaseSettings):
# 输出语言配置 - 支持 zh-CN中文和 en-US英文 # 输出语言配置 - 支持 zh-CN中文和 en-US英文
OUTPUT_LANGUAGE: str = "zh-CN" OUTPUT_LANGUAGE: str = "zh-CN"
# ============ Agent 模块配置 ============
# 嵌入模型配置
EMBEDDING_PROVIDER: str = "openai" # openai, ollama, litellm
EMBEDDING_MODEL: str = "text-embedding-3-small"
# 向量数据库配置
VECTOR_DB_PATH: str = "./data/vector_db" # 向量数据库持久化目录
# Agent 配置
AGENT_MAX_ITERATIONS: int = 50 # Agent 最大迭代次数
AGENT_TOKEN_BUDGET: int = 100000 # Agent Token 预算
AGENT_TIMEOUT_SECONDS: int = 1800 # Agent 超时时间30分钟
# 沙箱配置
SANDBOX_IMAGE: str = "deepaudit-sandbox:latest" # 沙箱 Docker 镜像
SANDBOX_MEMORY_LIMIT: str = "512m" # 沙箱内存限制
SANDBOX_CPU_LIMIT: float = 1.0 # 沙箱 CPU 限制
SANDBOX_TIMEOUT: int = 60 # 沙箱命令超时(秒)
SANDBOX_NETWORK_MODE: str = "none" # 沙箱网络模式 (none, bridge)
# RAG 配置
RAG_CHUNK_SIZE: int = 1500 # 代码块大小Token
RAG_CHUNK_OVERLAP: int = 50 # 代码块重叠Token
RAG_TOP_K: int = 10 # 检索返回数量
class Config: class Config:
case_sensitive = True case_sensitive = True

View File

@ -28,3 +28,4 @@ def get_password_hash(password: str) -> str:
return pwd_context.hash(password) return pwd_context.hash(password)

View File

@ -11,3 +11,4 @@ class Base:
return cls.__name__.lower() + "s" return cls.__name__.lower() + "s"

View File

@ -1,3 +1,4 @@
from contextlib import asynccontextmanager
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from app.core.config import settings from app.core.config import settings
@ -16,3 +17,14 @@ async def get_db():
await session.close() await session.close()
@asynccontextmanager
async def async_session_factory():
"""Async context manager for creating database sessions"""
async with AsyncSessionLocal() as session:
try:
yield session
finally:
await session.close()

View File

@ -1,8 +1,15 @@
from .user import User from .user import User
from .user_config import UserConfig
from .project import Project, ProjectMember from .project import Project, ProjectMember
from .audit import AuditTask, AuditIssue from .audit import AuditTask, AuditIssue
from .analysis import InstantAnalysis from .analysis import InstantAnalysis
from .prompt_template import PromptTemplate from .prompt_template import PromptTemplate
from .audit_rule import AuditRuleSet, AuditRule from .audit_rule import AuditRuleSet, AuditRule
from .agent_task import (
AgentTask, AgentEvent, AgentFinding,
AgentTaskStatus, AgentTaskPhase, AgentEventType,
VulnerabilitySeverity, VulnerabilityType, FindingStatus
)

View File

@ -0,0 +1,443 @@
"""
Agent 审计任务模型
支持 AI Agent 自主漏洞挖掘和验证
"""
import uuid
from datetime import datetime
from typing import Optional, List, TYPE_CHECKING
from sqlalchemy import (
Column, String, Integer, Float, Text, Boolean,
DateTime, ForeignKey, Enum as SQLEnum, JSON
)
from sqlalchemy.orm import relationship
from sqlalchemy.sql import func
from app.db.base import Base
if TYPE_CHECKING:
from .project import Project
class AgentTaskStatus:
    """Agent task lifecycle states (plain string constants, not a SQL enum)."""
    PENDING = "pending"            # queued, not started yet
    INITIALIZING = "initializing"  # setting up resources
    PLANNING = "planning"          # building the audit plan
    INDEXING = "indexing"          # indexing the code base
    ANALYZING = "analyzing"        # vulnerability analysis in progress
    VERIFYING = "verifying"        # verifying candidate findings
    REPORTING = "reporting"        # generating the final report
    COMPLETED = "completed"        # finished successfully
    FAILED = "failed"              # aborted with an error
    CANCELLED = "cancelled"        # cancelled by the user
    PAUSED = "paused"              # paused
class AgentTaskPhase:
    """Execution phases of an agent run, listed in their natural order."""
    PLANNING = "planning"
    INDEXING = "indexing"
    RECONNAISSANCE = "reconnaissance"
    ANALYSIS = "analysis"
    VERIFICATION = "verification"
    REPORTING = "reporting"
class AgentTask(Base):
    """Agent audit task: configuration, live progress counters and results."""
    __tablename__ = "agent_tasks"
    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    project_id = Column(String(36), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False)
    # Basic task information
    name = Column(String(255), nullable=True)
    description = Column(Text, nullable=True)
    task_type = Column(String(50), default="agent_audit")
    # Task configuration
    audit_scope = Column(JSON, nullable=True)  # audit-scope settings
    target_vulnerabilities = Column(JSON, nullable=True)  # vulnerability types to hunt for
    verification_level = Column(String(50), default="sandbox")  # analysis_only, sandbox, generate_poc
    # Branch (repository-backed projects)
    branch_name = Column(String(255), nullable=True)
    # Patterns excluded from the scan
    exclude_patterns = Column(JSON, nullable=True)
    # Explicit list of files to scan (None = whole project)
    target_files = Column(JSON, nullable=True)
    # LLM configuration
    llm_config = Column(JSON, nullable=True)
    # Agent-specific configuration and budgets
    agent_config = Column(JSON, nullable=True)
    max_iterations = Column(Integer, default=50)     # max agent iterations
    token_budget = Column(Integer, default=100000)   # token budget
    timeout_seconds = Column(Integer, default=1800)  # timeout in seconds
    # State
    status = Column(String(20), default=AgentTaskStatus.PENDING)
    current_phase = Column(String(50), nullable=True)
    current_step = Column(String(255), nullable=True)  # human-readable current step
    error_message = Column(Text, nullable=True)
    # Progress counters
    total_files = Column(Integer, default=0)
    indexed_files = Column(Integer, default=0)
    analyzed_files = Column(Integer, default=0)
    total_chunks = Column(Integer, default=0)  # total code chunks
    # Agent statistics
    total_iterations = Column(Integer, default=0)  # agent iteration count
    tool_calls_count = Column(Integer, default=0)  # tool-call count
    tokens_used = Column(Integer, default=0)       # tokens consumed so far
    # Finding statistics
    findings_count = Column(Integer, default=0)        # total findings
    verified_count = Column(Integer, default=0)        # verified findings
    false_positive_count = Column(Integer, default=0)  # false positives
    # Per-severity counters
    critical_count = Column(Integer, default=0)
    high_count = Column(Integer, default=0)
    medium_count = Column(Integer, default=0)
    low_count = Column(Integer, default=0)
    # Quality scores
    quality_score = Column(Float, default=0.0)
    security_score = Column(Float, default=0.0)
    # Audit plan produced by the agent
    audit_plan = Column(JSON, nullable=True)
    # Timestamps
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(DateTime(timezone=True), onupdate=func.now())
    started_at = Column(DateTime(timezone=True), nullable=True)
    completed_at = Column(DateTime(timezone=True), nullable=True)
    # Creator
    created_by = Column(String(36), ForeignKey("users.id"), nullable=False)
    # Relationships
    project = relationship("Project", back_populates="agent_tasks")
    events = relationship("AgentEvent", back_populates="task", cascade="all, delete-orphan", order_by="AgentEvent.created_at")
    findings = relationship("AgentFinding", back_populates="task", cascade="all, delete-orphan")
    def __repr__(self):
        return f"<AgentTask {self.id} - {self.status}>"
    @property
    def progress_percentage(self) -> float:
        """Estimated overall progress in percent (0-100).

        Completed tasks report 100; failed/cancelled report 0. Otherwise the
        weights of every phase before ``current_phase`` are summed, plus an
        in-phase estimate (file ratio for indexing/analysis, 50% for the
        rest). Capped at 99 until the task actually completes.
        """
        if self.status == AgentTaskStatus.COMPLETED:
            return 100.0
        if self.status in [AgentTaskStatus.FAILED, AgentTaskStatus.CANCELLED]:
            return 0.0
        phase_weights = {
            AgentTaskPhase.PLANNING: 5,
            AgentTaskPhase.INDEXING: 15,
            AgentTaskPhase.RECONNAISSANCE: 10,
            AgentTaskPhase.ANALYSIS: 50,
            AgentTaskPhase.VERIFICATION: 15,
            AgentTaskPhase.REPORTING: 5,
        }
        completed_weight = 0.0
        for phase, weight in phase_weights.items():
            if phase == self.current_phase:
                # Estimate progress inside the current phase.
                if phase == AgentTaskPhase.INDEXING and self.total_files > 0:
                    completed_weight += weight * (self.indexed_files / self.total_files)
                elif phase == AgentTaskPhase.ANALYSIS and self.total_files > 0:
                    completed_weight += weight * (self.analyzed_files / self.total_files)
                else:
                    completed_weight += weight * 0.5
                return min(completed_weight, 99.0)
            completed_weight += weight
        # Bug fix: when current_phase is unset or unknown (e.g. a PENDING or
        # INITIALIZING task) the old loop summed every phase weight and
        # reported 99% progress; report 0 instead.
        return 0.0
class AgentEventType:
    """Event types emitted during an agent run (for live logs and replay)."""
    # System events
    TASK_START = "task_start"
    TASK_COMPLETE = "task_complete"
    TASK_ERROR = "task_error"
    TASK_CANCEL = "task_cancel"
    # Phase events
    PHASE_START = "phase_start"
    PHASE_COMPLETE = "phase_complete"
    # Agent reasoning
    THINKING = "thinking"
    PLANNING = "planning"
    DECISION = "decision"
    # Tool calls
    TOOL_CALL = "tool_call"
    TOOL_RESULT = "tool_result"
    TOOL_ERROR = "tool_error"
    # RAG-related
    RAG_QUERY = "rag_query"
    RAG_RESULT = "rag_result"
    # Finding-related
    FINDING_NEW = "finding_new"
    FINDING_UPDATE = "finding_update"
    FINDING_VERIFIED = "finding_verified"
    FINDING_FALSE_POSITIVE = "finding_false_positive"
    # Sandbox-related
    SANDBOX_START = "sandbox_start"
    SANDBOX_EXEC = "sandbox_exec"
    SANDBOX_RESULT = "sandbox_result"
    SANDBOX_ERROR = "sandbox_error"
    # Progress
    PROGRESS = "progress"
    # Log levels
    INFO = "info"
    WARNING = "warning"
    ERROR = "error"
    DEBUG = "debug"
class AgentEvent(Base):
    """A single agent execution event (used for live logs and replay)."""
    __tablename__ = "agent_events"
    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    task_id = Column(String(36), ForeignKey("agent_tasks.id", ondelete="CASCADE"), nullable=False, index=True)
    # Event identity
    event_type = Column(String(50), nullable=False, index=True)
    phase = Column(String(50), nullable=True)
    # Event payload
    message = Column(Text, nullable=True)
    # Tool-call details (populated only for tool events)
    tool_name = Column(String(100), nullable=True)
    tool_input = Column(JSON, nullable=True)
    tool_output = Column(JSON, nullable=True)
    tool_duration_ms = Column(Integer, nullable=True)  # tool execution time (ms)
    # Associated finding, if any
    finding_id = Column(String(36), nullable=True)
    # Tokens consumed by this step
    tokens_used = Column(Integer, default=0)
    # Free-form metadata
    event_metadata = Column(JSON, nullable=True)
    # Sequence number (for stable ordering)
    sequence = Column(Integer, default=0, index=True)
    # Timestamp
    created_at = Column(DateTime(timezone=True), server_default=func.now(), index=True)
    # Relationships
    task = relationship("AgentTask", back_populates="events")
    def __repr__(self):
        return f"<AgentEvent {self.event_type} - {self.message[:50] if self.message else ''}>"
    def to_sse_dict(self) -> dict:
        """Serialise the event to the dict shape pushed over SSE."""
        return {
            "id": self.id,
            "type": self.event_type,
            "phase": self.phase,
            "message": self.message,
            "tool_name": self.tool_name,
            "tool_input": self.tool_input,
            "tool_output": self.tool_output,
            "tool_duration_ms": self.tool_duration_ms,
            "finding_id": self.finding_id,
            "tokens_used": self.tokens_used,
            "metadata": self.event_metadata,
            "sequence": self.sequence,
            "timestamp": self.created_at.isoformat() if self.created_at else None,
        }
class VulnerabilitySeverity:
    """Vulnerability severity levels (stored as plain strings)."""
    CRITICAL = "critical"
    HIGH = "high"
    MEDIUM = "medium"
    LOW = "low"
    INFO = "info"
class VulnerabilityType:
    """Vulnerability class constants recognised by the agent."""
    SQL_INJECTION = "sql_injection"
    NOSQL_INJECTION = "nosql_injection"
    XSS = "xss"
    COMMAND_INJECTION = "command_injection"
    CODE_INJECTION = "code_injection"
    PATH_TRAVERSAL = "path_traversal"
    FILE_INCLUSION = "file_inclusion"
    SSRF = "ssrf"
    XXE = "xxe"
    DESERIALIZATION = "deserialization"
    AUTH_BYPASS = "auth_bypass"
    IDOR = "idor"
    SENSITIVE_DATA_EXPOSURE = "sensitive_data_exposure"
    HARDCODED_SECRET = "hardcoded_secret"
    WEAK_CRYPTO = "weak_crypto"
    RACE_CONDITION = "race_condition"
    BUSINESS_LOGIC = "business_logic"
    MEMORY_CORRUPTION = "memory_corruption"
    OTHER = "other"
class FindingStatus:
    """Workflow states of a finding."""
    NEW = "new"                        # newly discovered
    ANALYZING = "analyzing"            # under analysis
    VERIFIED = "verified"              # confirmed exploitable
    FALSE_POSITIVE = "false_positive"  # false positive
    NEEDS_REVIEW = "needs_review"      # needs human review
    FIXED = "fixed"                    # remediated
    WONT_FIX = "wont_fix"              # accepted risk, will not fix
    DUPLICATE = "duplicate"            # duplicate of another finding
class AgentFinding(Base):
    """A vulnerability discovered by the agent."""
    __tablename__ = "agent_findings"
    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    task_id = Column(String(36), ForeignKey("agent_tasks.id", ondelete="CASCADE"), nullable=False, index=True)
    # Basic vulnerability information
    vulnerability_type = Column(String(100), nullable=False, index=True)
    severity = Column(String(20), nullable=False, index=True)
    title = Column(String(500), nullable=False)
    description = Column(Text, nullable=True)
    # Location
    file_path = Column(String(500), nullable=True, index=True)
    line_start = Column(Integer, nullable=True)
    line_end = Column(Integer, nullable=True)
    column_start = Column(Integer, nullable=True)
    column_end = Column(Integer, nullable=True)
    function_name = Column(String(255), nullable=True)
    class_name = Column(String(255), nullable=True)
    # Code snippets
    code_snippet = Column(Text, nullable=True)
    code_context = Column(Text, nullable=True)  # wider surrounding context
    # Data-flow information
    source = Column(Text, nullable=True)  # taint source
    sink = Column(Text, nullable=True)  # dangerous function (sink)
    dataflow_path = Column(JSON, nullable=True)  # source-to-sink path
    # Verification
    status = Column(String(30), default=FindingStatus.NEW, index=True)
    is_verified = Column(Boolean, default=False)
    verification_method = Column(Text, nullable=True)
    verification_result = Column(JSON, nullable=True)
    verified_at = Column(DateTime(timezone=True), nullable=True)
    # Proof of concept
    has_poc = Column(Boolean, default=False)
    poc_code = Column(Text, nullable=True)
    poc_description = Column(Text, nullable=True)
    poc_steps = Column(JSON, nullable=True)  # reproduction steps
    # Remediation
    suggestion = Column(Text, nullable=True)
    fix_code = Column(Text, nullable=True)
    fix_description = Column(Text, nullable=True)
    references = Column(JSON, nullable=True)  # reference links (CWE, OWASP, ...)
    # AI explanation
    ai_explanation = Column(Text, nullable=True)
    ai_confidence = Column(Float, nullable=True)  # AI confidence, 0-1
    # XAI (explainable AI) fields
    xai_what = Column(Text, nullable=True)
    xai_why = Column(Text, nullable=True)
    xai_how = Column(Text, nullable=True)
    xai_impact = Column(Text, nullable=True)
    # Matched rule
    matched_rule_code = Column(String(100), nullable=True)
    matched_pattern = Column(Text, nullable=True)
    # Optional CVSS score
    cvss_score = Column(Float, nullable=True)
    cvss_vector = Column(String(100), nullable=True)
    # Metadata
    finding_metadata = Column(JSON, nullable=True)
    tags = Column(JSON, nullable=True)
    # Deduplication
    fingerprint = Column(String(64), nullable=True, index=True)  # dedup fingerprint
    # Timestamps
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(DateTime(timezone=True), onupdate=func.now())
    # Relationships
    task = relationship("AgentTask", back_populates="findings")
    def __repr__(self):
        return f"<AgentFinding {self.vulnerability_type} - {self.severity} - {self.file_path}>"
    def generate_fingerprint(self) -> str:
        """Build a 16-hex-char dedup fingerprint from type/location/snippet."""
        import hashlib
        components = [
            self.vulnerability_type or "",
            self.file_path or "",
            str(self.line_start or 0),
            self.function_name or "",
            (self.code_snippet or "")[:200],
        ]
        content = "|".join(components)
        return hashlib.sha256(content.encode()).hexdigest()[:16]
    def to_dict(self) -> dict:
        """Serialise the finding to a plain, JSON-friendly dict."""
        return {
            "id": self.id,
            "task_id": self.task_id,
            "vulnerability_type": self.vulnerability_type,
            "severity": self.severity,
            "title": self.title,
            "description": self.description,
            "file_path": self.file_path,
            "line_start": self.line_start,
            "line_end": self.line_end,
            "code_snippet": self.code_snippet,
            "status": self.status,
            "is_verified": self.is_verified,
            "has_poc": self.has_poc,
            "poc_code": self.poc_code,
            "suggestion": self.suggestion,
            "fix_code": self.fix_code,
            "ai_explanation": self.ai_explanation,
            "ai_confidence": self.ai_confidence,
            "created_at": self.created_at.isoformat() if self.created_at else None,
        }

View File

@ -23,3 +23,4 @@ class InstantAnalysis(Base):
user = relationship("User", backref="instant_analyses") user = relationship("User", backref="instant_analyses")

View File

@ -31,6 +31,7 @@ class Project(Base):
owner = relationship("User", backref="projects") owner = relationship("User", backref="projects")
members = relationship("ProjectMember", back_populates="project", cascade="all, delete-orphan") members = relationship("ProjectMember", back_populates="project", cascade="all, delete-orphan")
tasks = relationship("AuditTask", back_populates="project", cascade="all, delete-orphan") tasks = relationship("AuditTask", back_populates="project", cascade="all, delete-orphan")
agent_tasks = relationship("AgentTask", back_populates="project", cascade="all, delete-orphan")
class ProjectMember(Base): class ProjectMember(Base):
__tablename__ = "project_members" __tablename__ = "project_members"
@ -49,3 +50,4 @@ class ProjectMember(Base):
user = relationship("User", backref="project_memberships") user = relationship("User", backref="project_memberships")

View File

@ -24,3 +24,4 @@ class User(Base):
updated_at = Column(DateTime(timezone=True), onupdate=func.now()) updated_at = Column(DateTime(timezone=True), onupdate=func.now())

View File

@ -29,3 +29,4 @@ class UserConfig(Base):
user = relationship("User", backref="config") user = relationship("User", backref="config")

View File

@ -9,3 +9,4 @@ class TokenPayload(BaseModel):
sub: Optional[str] = None sub: Optional[str] = None

View File

@ -40,3 +40,4 @@ class UserListResponse(BaseModel):
limit: int limit: int

View File

@ -0,0 +1,58 @@
"""
DeepAudit Agent 服务模块
基于 LangGraph AI Agent 代码安全审计
架构:
LangGraph 状态图工作流
START Recon Analysis Verification Report END
节点:
- Recon: 信息收集 (项目结构技术栈入口点)
- Analysis: 漏洞分析 (静态分析RAG模式匹配)
- Verification: 漏洞验证 (LLM 验证沙箱测试)
- Report: 报告生成
"""
# 从 graph 模块导入主要组件
from .graph import (
AgentRunner,
run_agent_task,
LLMService,
AuditState,
create_audit_graph,
)
# 事件管理
from .event_manager import EventManager, AgentEventEmitter
# Agent 类
from .agents import (
BaseAgent, AgentConfig, AgentResult,
OrchestratorAgent, ReconAgent, AnalysisAgent, VerificationAgent,
)
__all__ = [
# 核心 Runner
"AgentRunner",
"run_agent_task",
"LLMService",
# LangGraph
"AuditState",
"create_audit_graph",
# 事件管理
"EventManager",
"AgentEventEmitter",
# Agent 类
"BaseAgent",
"AgentConfig",
"AgentResult",
"OrchestratorAgent",
"ReconAgent",
"AnalysisAgent",
"VerificationAgent",
]

View File

@ -0,0 +1,21 @@
"""
混合 Agent 架构
包含 OrchestratorReconAnalysis Verification Agent
"""
from .base import BaseAgent, AgentConfig, AgentResult
from .orchestrator import OrchestratorAgent
from .recon import ReconAgent
from .analysis import AnalysisAgent
from .verification import VerificationAgent
__all__ = [
"BaseAgent",
"AgentConfig",
"AgentResult",
"OrchestratorAgent",
"ReconAgent",
"AnalysisAgent",
"VerificationAgent",
]

View File

@ -0,0 +1,469 @@
"""
Analysis Agent (漏洞分析层)
负责代码审计RAG 查询模式匹配数据流分析
类型: ReAct
"""
import asyncio
import logging
from typing import List, Dict, Any, Optional
from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern
logger = logging.getLogger(__name__)
ANALYSIS_SYSTEM_PROMPT = """你是 DeepAudit 的漏洞分析 Agent负责深度代码安全分析。
## 你的职责
1. 使用静态分析工具快速扫描
2. 使用 RAG 进行语义代码搜索
3. 追踪数据流从用户输入到危险函数
4. 分析业务逻辑漏洞
5. 评估漏洞严重程度
## 你可以使用的工具
### 外部扫描工具
- semgrep_scan: Semgrep 静态分析推荐首先使用
- bandit_scan: Python 安全扫描
### RAG 语义搜索
- rag_query: 语义代码搜索
- security_search: 安全相关代码搜索
- function_context: 函数上下文分析
### 深度分析
- pattern_match: 危险模式匹配
- code_analysis: LLM 深度代码分析
- dataflow_analysis: 数据流追踪
- vulnerability_validation: 漏洞验证
### 文件操作
- read_file: 读取文件
- search_code: 关键字搜索
## 分析策略
1. **快速扫描**: 先用 Semgrep 快速发现问题
2. **语义搜索**: RAG 找到相关代码
3. **深度分析**: 对可疑代码进行 LLM 分析
4. **数据流追踪**: 追踪用户输入的流向
## 重点关注
- SQL 注入NoSQL 注入
- XSS反射型存储型DOM型
- 命令注入代码注入
- 路径遍历任意文件访问
- SSRFXXE
- 不安全的反序列化
- 认证/授权绕过
- 敏感信息泄露
## 输出格式
发现漏洞时返回结构化信息
```json
{
"findings": [
{
"vulnerability_type": "漏洞类型",
"severity": "critical/high/medium/low",
"title": "漏洞标题",
"description": "详细描述",
"file_path": "文件路径",
"line_start": 行号,
"code_snippet": "代码片段",
"source": "污点源",
"sink": "危险函数",
"suggestion": "修复建议",
"needs_verification": true/false
}
]
}
```
请系统性地分析代码发现真实的安全漏洞"""
class AnalysisAgent(BaseAgent):
    """
    Vulnerability-analysis agent (ReAct pattern).

    Pipeline: external static scanners (Semgrep, plus Bandit for Python) ->
    LLM deep analysis of entry points -> keyword search of high-risk areas ->
    RAG semantic search per target vulnerability class. Results are
    deduplicated by (file, line, type) before being returned.
    """
    def __init__(
        self,
        llm_service,
        tools: Dict[str, Any],
        event_emitter=None,
    ):
        config = AgentConfig(
            name="Analysis",
            agent_type=AgentType.ANALYSIS,
            pattern=AgentPattern.REACT,
            max_iterations=30,
            system_prompt=ANALYSIS_SYSTEM_PROMPT,
            tools=[
                "semgrep_scan", "bandit_scan",
                "rag_query", "security_search", "function_context",
                "pattern_match", "code_analysis", "dataflow_analysis",
                "vulnerability_validation",
                "read_file", "search_code",
            ],
        )
        super().__init__(config, llm_service, tools, event_emitter)
    async def run(self, input_data: Dict[str, Any]) -> AgentResult:
        """Execute the analysis phase described by ``input_data``.

        Returns an AgentResult whose ``data["findings"]`` is the deduplicated
        list of candidate vulnerabilities; failures are reported via
        ``success=False`` rather than raising.
        """
        import time
        start_time = time.time()
        phase_name = input_data.get("phase_name", "analysis")
        config = input_data.get("config", {})
        plan = input_data.get("plan", {})
        previous_results = input_data.get("previous_results", {})
        # Pull context produced by the Recon agent, falling back to the plan.
        recon_data = previous_results.get("recon", {}).get("data", {})
        high_risk_areas = recon_data.get("high_risk_areas", plan.get("high_risk_areas", []))
        tech_stack = recon_data.get("tech_stack", {})
        entry_points = recon_data.get("entry_points", [])
        try:
            all_findings = []
            # 1. Static-analysis stage
            if phase_name in ["static_analysis", "analysis"]:
                await self.emit_thinking("执行静态代码分析...")
                static_findings = await self._run_static_analysis(tech_stack)
                all_findings.extend(static_findings)
            # 2. Deep-analysis stage
            if phase_name in ["deep_analysis", "analysis"]:
                await self.emit_thinking("执行深度漏洞分析...")
                # Analyse entry points with the LLM.
                deep_findings = await self._analyze_entry_points(entry_points)
                all_findings.extend(deep_findings)
                # Keyword-scan high-risk areas for dangerous sinks.
                risk_findings = await self._analyze_high_risk_areas(high_risk_areas)
                all_findings.extend(risk_findings)
                # Semantic search per target vulnerability class.
                vuln_types = config.get("target_vulnerabilities", [
                    "sql_injection", "xss", "command_injection",
                    "path_traversal", "ssrf", "hardcoded_secret",
                ])
                for vuln_type in vuln_types[:5]:  # cap the number of searches
                    if self.is_cancelled:
                        break
                    await self.emit_thinking(f"搜索 {vuln_type} 相关代码...")
                    vuln_findings = await self._search_vulnerability_pattern(vuln_type)
                    all_findings.extend(vuln_findings)
            # Deduplicate by (file, line, type).
            all_findings = self._deduplicate_findings(all_findings)
            duration_ms = int((time.time() - start_time) * 1000)
            await self.emit_event(
                "info",
                f"分析完成: 发现 {len(all_findings)} 个潜在漏洞"
            )
            return AgentResult(
                success=True,
                data={"findings": all_findings},
                iterations=self._iteration,
                tool_calls=self._tool_calls,
                tokens_used=self._total_tokens,
                duration_ms=duration_ms,
            )
        except Exception as e:
            logger.error(f"Analysis agent failed: {e}", exc_info=True)
            return AgentResult(success=False, error=str(e))
    async def _run_static_analysis(self, tech_stack: Dict) -> List[Dict]:
        """Run external static scanners and normalise their findings."""
        findings = []
        # Semgrep scan (language-agnostic security ruleset).
        semgrep_tool = self.tools.get("semgrep_scan")
        if semgrep_tool:
            await self.emit_tool_call("semgrep_scan", {"rules": "p/security-audit"})
            result = await semgrep_tool.execute(rules="p/security-audit", max_results=30)
            if result.success and result.metadata.get("findings_count", 0) > 0:
                for finding in result.metadata.get("findings", []):
                    findings.append({
                        "vulnerability_type": self._map_semgrep_rule(finding.get("check_id", "")),
                        "severity": self._map_semgrep_severity(finding.get("extra", {}).get("severity", "")),
                        "title": finding.get("check_id", "Semgrep Finding"),
                        "description": finding.get("extra", {}).get("message", ""),
                        "file_path": finding.get("path", ""),
                        "line_start": finding.get("start", {}).get("line", 0),
                        "code_snippet": finding.get("extra", {}).get("lines", ""),
                        "source": "semgrep",
                        "needs_verification": True,
                    })
        # Bandit scan (only when the recon tech stack includes Python).
        languages = tech_stack.get("languages", [])
        if "Python" in languages:
            bandit_tool = self.tools.get("bandit_scan")
            if bandit_tool:
                await self.emit_tool_call("bandit_scan", {})
                result = await bandit_tool.execute()
                if result.success and result.metadata.get("findings_count", 0) > 0:
                    for finding in result.metadata.get("findings", []):
                        findings.append({
                            "vulnerability_type": self._map_bandit_test(finding.get("test_id", "")),
                            "severity": finding.get("issue_severity", "medium").lower(),
                            "title": finding.get("test_name", "Bandit Finding"),
                            "description": finding.get("issue_text", ""),
                            "file_path": finding.get("filename", ""),
                            "line_start": finding.get("line_number", 0),
                            "code_snippet": finding.get("code", ""),
                            "source": "bandit",
                            "needs_verification": True,
                        })
        return findings
    async def _analyze_entry_points(self, entry_points: List[Dict]) -> List[Dict]:
        """LLM-analyse code around each recon-discovered entry point."""
        findings = []
        code_analysis_tool = self.tools.get("code_analysis")
        read_tool = self.tools.get("read_file")
        if not code_analysis_tool or not read_tool:
            return findings
        # Only the first few entry points are analysed to bound cost.
        for ep in entry_points[:10]:
            if self.is_cancelled:
                break
            file_path = ep.get("file", "")
            line = ep.get("line", 1)
            if not file_path:
                continue
            # Read a window of code around the entry point.
            read_result = await read_tool.execute(
                file_path=file_path,
                start_line=max(1, line - 20),
                end_line=line + 50,
            )
            if not read_result.success:
                continue
            # LLM deep analysis of the window.
            analysis_result = await code_analysis_tool.execute(
                code=read_result.data,
                file_path=file_path,
            )
            if analysis_result.success and analysis_result.metadata.get("issues"):
                for issue in analysis_result.metadata["issues"]:
                    findings.append({
                        "vulnerability_type": issue.get("type", "unknown"),
                        "severity": issue.get("severity", "medium"),
                        "title": issue.get("title", "Security Issue"),
                        "description": issue.get("description", ""),
                        "file_path": file_path,
                        "line_start": issue.get("line", line),
                        "code_snippet": issue.get("code_snippet", ""),
                        "suggestion": issue.get("suggestion", ""),
                        "source": "code_analysis",
                        "needs_verification": True,
                    })
        return findings
    async def _analyze_high_risk_areas(self, high_risk_areas: List[str]) -> List[Dict]:
        """Keyword-search dangerous sink patterns across the project.

        Matches inside a recon-flagged high-risk area are rated ``high``,
        everything else ``medium``.
        """
        findings = []
        search_tool = self.tools.get("search_code")
        if not search_tool:
            return findings
        # Dangerous sink substrings and the vulnerability class each suggests.
        dangerous_patterns = [
            ("execute(", "sql_injection"),
            ("eval(", "code_injection"),
            ("system(", "command_injection"),
            ("exec(", "command_injection"),
            ("innerHTML", "xss"),
            ("document.write", "xss"),
        ]
        for pattern, vuln_type in dangerous_patterns[:5]:
            if self.is_cancelled:
                break
            result = await search_tool.execute(keyword=pattern, max_results=10)
            if result.success and result.metadata.get("matches", 0) > 0:
                for match in result.metadata.get("results", [])[:3]:
                    file_path = match.get("file", "")
                    # Was this hit inside a recon-flagged high-risk area?
                    in_high_risk = any(
                        area in file_path for area in high_risk_areas
                    )
                    # NOTE(review): every match is reported, not only the
                    # high-risk ones -- the original code had the dead
                    # condition `if in_high_risk or True` with a
                    # "temporarily include all" comment; behaviour kept.
                    findings.append({
                        "vulnerability_type": vuln_type,
                        "severity": "high" if in_high_risk else "medium",
                        "title": f"疑似 {vuln_type}: {pattern}",
                        "description": f"在 {file_path} 中发现危险模式 {pattern}",
                        "file_path": file_path,
                        "line_start": match.get("line", 0),
                        "code_snippet": match.get("match", ""),
                        "source": "pattern_search",
                        "needs_verification": True,
                    })
        return findings
    async def _search_vulnerability_pattern(self, vuln_type: str) -> List[Dict]:
        """RAG semantic search for code related to one vulnerability class."""
        findings = []
        security_tool = self.tools.get("security_search")
        if not security_tool:
            return findings
        result = await security_tool.execute(
            vulnerability_type=vuln_type,
            top_k=10,
        )
        if result.success and result.metadata.get("results_count", 0) > 0:
            for item in result.metadata.get("results", [])[:5]:
                findings.append({
                    "vulnerability_type": vuln_type,
                    "severity": "medium",
                    "title": f"疑似 {vuln_type}",
                    "description": f"通过语义搜索发现可能存在 {vuln_type}",
                    "file_path": item.get("file_path", ""),
                    "line_start": item.get("line_start", 0),
                    "code_snippet": item.get("content", "")[:500],
                    "source": "rag_search",
                    "needs_verification": True,
                })
        return findings
    def _deduplicate_findings(self, findings: List[Dict]) -> List[Dict]:
        """Drop duplicates keyed on (file, start line, vulnerability type)."""
        seen = set()
        unique = []
        for finding in findings:
            key = (
                finding.get("file_path", ""),
                finding.get("line_start", 0),
                finding.get("vulnerability_type", ""),
            )
            if key not in seen:
                seen.add(key)
                unique.append(finding)
        return unique
    def _map_semgrep_rule(self, rule_id: str) -> str:
        """Map a Semgrep rule id to an internal vulnerability type."""
        rule_lower = rule_id.lower()
        if "sql" in rule_lower:
            return "sql_injection"
        elif "xss" in rule_lower:
            return "xss"
        elif "command" in rule_lower or "injection" in rule_lower:
            return "command_injection"
        elif "path" in rule_lower or "traversal" in rule_lower:
            return "path_traversal"
        elif "ssrf" in rule_lower:
            return "ssrf"
        elif "deserial" in rule_lower:
            return "deserialization"
        elif "secret" in rule_lower or "password" in rule_lower or "key" in rule_lower:
            return "hardcoded_secret"
        elif "crypto" in rule_lower:
            return "weak_crypto"
        else:
            return "other"
    def _map_semgrep_severity(self, severity: str) -> str:
        """Map a Semgrep severity label to our scale (defaults to medium)."""
        mapping = {
            "ERROR": "high",
            "WARNING": "medium",
            "INFO": "low",
        }
        return mapping.get(severity, "medium")
    def _map_bandit_test(self, test_id: str) -> str:
        """Map a Bandit test id (B1xx-B7xx) to an internal vulnerability type."""
        mappings = {
            "B101": "assert_used",
            "B102": "exec_used",
            "B103": "hardcoded_password",
            "B104": "hardcoded_bind_all",
            "B105": "hardcoded_password",
            "B106": "hardcoded_password",
            "B107": "hardcoded_password",
            "B108": "hardcoded_tmp",
            "B301": "deserialization",
            "B302": "deserialization",
            "B303": "weak_crypto",
            "B304": "weak_crypto",
            "B305": "weak_crypto",
            "B306": "weak_crypto",
            "B307": "code_injection",
            "B308": "code_injection",
            "B310": "ssrf",
            "B311": "weak_random",
            "B312": "telnet",
            "B501": "ssl_verify",
            "B502": "ssl_verify",
            "B503": "ssl_verify",
            "B504": "ssl_verify",
            "B505": "weak_crypto",
            "B506": "yaml_load",
            "B507": "ssh_key",
            "B601": "command_injection",
            "B602": "command_injection",
            "B603": "command_injection",
            "B604": "command_injection",
            "B605": "command_injection",
            "B606": "command_injection",
            "B607": "command_injection",
            "B608": "sql_injection",
            "B609": "sql_injection",
            "B610": "sql_injection",
            "B611": "sql_injection",
            "B701": "xss",
            "B702": "xss",
            "B703": "xss",
        }
        return mappings.get(test_id, "other")

View File

@ -0,0 +1,284 @@
"""
Agent 基类
定义 Agent 的基本接口和通用功能
"""
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, AsyncGenerator
from dataclasses import dataclass, field
from enum import Enum
import logging
logger = logging.getLogger(__name__)
class AgentType(Enum):
    """Role an agent plays in the audit pipeline."""
    ORCHESTRATOR = "orchestrator"  # plans phases and dispatches sub-agents
    RECON = "recon"                # project reconnaissance / information gathering
    ANALYSIS = "analysis"          # code analysis and vulnerability discovery
    VERIFICATION = "verification"  # confirms findings, builds PoCs
class AgentPattern(Enum):
    """Execution pattern an agent follows."""
    REACT = "react"  # ReAct: think -> act -> observe loop
    PLAN_AND_EXECUTE = "plan_execute"  # plan first, then execute the plan
@dataclass
class AgentConfig:
    """Static configuration for a single agent instance."""
    name: str                      # display name used in emitted events
    agent_type: AgentType
    pattern: AgentPattern = AgentPattern.REACT
    # LLM settings (model=None means "use the LLM service default")
    model: Optional[str] = None
    temperature: float = 0.1
    max_tokens: int = 4096
    # Execution limits
    # NOTE(review): enforcement of these limits is not visible in this class;
    # confirm that concrete agents honour max_iterations/timeout_seconds.
    max_iterations: int = 20
    timeout_seconds: int = 600
    # Names of the tools this agent is allowed to use
    tools: List[str] = field(default_factory=list)
    # System prompt injected as the first chat message
    system_prompt: Optional[str] = None
    # Free-form extra data
    metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass
class AgentResult:
    """Outcome of one agent run, plus execution statistics."""
    success: bool
    data: Any = None               # agent-specific payload (e.g. findings dict)
    error: Optional[str] = None    # human-readable error when success is False
    # Execution statistics
    iterations: int = 0
    tool_calls: int = 0
    tokens_used: int = 0
    duration_ms: int = 0
    # Intermediate reasoning / tool steps collected during the run
    intermediate_steps: List[Dict[str, Any]] = field(default_factory=list)
    # Free-form extra data
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Serialise the result; ``intermediate_steps`` is not included."""
        return {
            "success": self.success,
            "data": self.data,
            "error": self.error,
            "iterations": self.iterations,
            "tool_calls": self.tool_calls,
            "tokens_used": self.tokens_used,
            "duration_ms": self.duration_ms,
            "metadata": self.metadata,
        }
class BaseAgent(ABC):
    """
    Base class for all audit agents.

    Subclasses implement :meth:`run`; this class provides the shared
    plumbing: tool invocation, LLM calls, event emission, cooperative
    cancellation and run statistics.
    """

    def __init__(
        self,
        config: AgentConfig,
        llm_service,
        tools: Dict[str, Any],
        event_emitter=None,
    ):
        """
        Initialise the agent.

        Args:
            config: static agent configuration
            llm_service: LLM service used by call_llm()
            tools: mapping of tool name -> tool object; each tool must
                expose an async ``execute(**kwargs)`` returning an object
                with a ``data`` attribute (see call_tool)
            event_emitter: optional emitter used by emit_event()
        """
        self.config = config
        self.llm_service = llm_service
        self.tools = tools
        self.event_emitter = event_emitter
        # Mutable run-time counters, reset only on construction
        self._iteration = 0
        self._total_tokens = 0
        self._tool_calls = 0
        self._cancelled = False

    @property
    def name(self) -> str:
        # Display name, taken from the configuration.
        return self.config.name

    @property
    def agent_type(self) -> AgentType:
        return self.config.agent_type

    @abstractmethod
    async def run(self, input_data: Dict[str, Any]) -> AgentResult:
        """
        Execute the agent's task.

        Args:
            input_data: task-specific input payload

        Returns:
            The agent's execution result.
        """
        pass

    def cancel(self):
        """Request cooperative cancellation; agents poll is_cancelled."""
        self._cancelled = True

    @property
    def is_cancelled(self) -> bool:
        return self._cancelled

    async def emit_event(
        self,
        event_type: str,
        message: str,
        **kwargs
    ):
        """Emit a structured event; a no-op when no emitter is configured."""
        if self.event_emitter:
            # Imported lazily to avoid a circular import with event_manager.
            from ..event_manager import AgentEventData
            await self.event_emitter.emit(AgentEventData(
                event_type=event_type,
                message=message,
                **kwargs
            ))

    async def emit_thinking(self, message: str):
        """Emit a "thinking" event, prefixed with this agent's name."""
        await self.emit_event("thinking", f"[{self.name}] {message}")

    async def emit_tool_call(self, tool_name: str, tool_input: Dict):
        """Emit a tool-call event carrying the tool name and input."""
        await self.emit_event(
            "tool_call",
            f"[{self.name}] 调用工具: {tool_name}",
            tool_name=tool_name,
            tool_input=tool_input,
        )

    async def emit_tool_result(self, tool_name: str, result: str, duration_ms: int):
        """Emit a tool-result event.

        NOTE(review): the ``result`` argument is accepted but not forwarded
        in the event payload — confirm whether that is intentional.
        """
        await self.emit_event(
            "tool_result",
            f"[{self.name}] {tool_name} 完成 ({duration_ms}ms)",
            tool_name=tool_name,
            tool_duration_ms=duration_ms,
        )

    async def call_tool(self, tool_name: str, **kwargs) -> Any:
        """
        Invoke a tool by name, emitting call/result events and timing it.

        Args:
            tool_name: tool name (key into self.tools)
            **kwargs: tool parameters

        Returns:
            The tool's result object, or None when the tool is unknown.
        """
        tool = self.tools.get(tool_name)
        if not tool:
            logger.warning(f"Tool not found: {tool_name}")
            return None
        self._tool_calls += 1
        await self.emit_tool_call(tool_name, kwargs)
        import time
        start = time.time()
        result = await tool.execute(**kwargs)
        duration_ms = int((time.time() - start) * 1000)
        # Only the first 200 characters of the result are reported.
        await self.emit_tool_result(tool_name, str(result.data)[:200], duration_ms)
        return result

    async def call_llm(
        self,
        messages: List[Dict[str, str]],
        tools: Optional[List[Dict]] = None,
    ) -> Dict[str, Any]:
        """
        Call the LLM; each call counts as one iteration.

        Args:
            messages: chat messages
            tools: optional tool descriptions for function calling

        Returns:
            The LLM response dict.

        Raises:
            Whatever the underlying LLM service raises (re-raised after logging).
        """
        self._iteration += 1
        # Delegates to the configured LLM service
        # (LangChain or a direct API client).
        try:
            response = await self.llm_service.chat_completion(
                messages=messages,
                temperature=self.config.temperature,
                max_tokens=self.config.max_tokens,
                tools=tools,
            )
            # Accumulate token usage when the service reports it.
            if response.get("usage"):
                self._total_tokens += response["usage"].get("total_tokens", 0)
            return response
        except Exception as e:
            logger.error(f"LLM call failed: {e}")
            raise

    def get_tool_descriptions(self) -> List[Dict[str, Any]]:
        """Build OpenAI-style function descriptions for the available tools."""
        descriptions = []
        for name, tool in self.tools.items():
            # Names starting with "_" are treated as internal and hidden.
            if name.startswith("_"):
                continue
            desc = {
                "type": "function",
                "function": {
                    "name": name,
                    "description": tool.description,
                }
            }
            # Attach the parameter schema when the tool declares one.
            # NOTE(review): .schema() is the pydantic v1 API; confirm the
            # installed pydantic version (v2 uses model_json_schema()).
            if hasattr(tool, 'args_schema') and tool.args_schema:
                desc["function"]["parameters"] = tool.args_schema.schema()
            descriptions.append(desc)
        return descriptions

    def get_stats(self) -> Dict[str, Any]:
        """Return a snapshot of this agent's execution counters."""
        return {
            "agent": self.name,
            "type": self.agent_type.value,
            "iterations": self._iteration,
            "tool_calls": self._tool_calls,
            "tokens_used": self._total_tokens,
        }

View File

@ -0,0 +1,381 @@
"""
Orchestrator Agent (编排层)
负责任务分解 Agent 调度和结果汇总
类型: Plan-and-Execute
"""
import asyncio
import logging
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern
logger = logging.getLogger(__name__)
@dataclass
class AuditPlan:
    """Audit plan produced by the orchestrator's planning step."""
    phases: List[Dict[str, Any]]      # ordered phases: {"name", "description", "agent"}
    high_risk_areas: List[str]        # directories/files to prioritise
    focus_vulnerabilities: List[str]  # vulnerability types to focus on
    estimated_steps: int              # rough step budget for the audit
    priority_files: List[str]         # files to audit first
    metadata: Dict[str, Any]          # raw LLM plan data (empty for fallback plans)
# System prompt for the orchestrator agent. Deliberately kept in Chinese:
# it is sent to the LLM verbatim and must not be altered for cosmetic reasons.
ORCHESTRATOR_SYSTEM_PROMPT = """你是 DeepAudit 的编排 Agent负责协调整个安全审计流程。
## 你的职责
1. 分析项目信息制定审计计划
2. 调度子 AgentReconAnalysisVerification执行任务
3. 汇总审计结果生成报告
## 审计流程
1. **信息收集阶段**: 调度 Recon Agent 收集项目信息
- 项目结构分析
- 技术栈识别
- 入口点识别
- 依赖分析
2. **漏洞分析阶段**: 调度 Analysis Agent 进行代码分析
- 静态代码分析
- 语义搜索
- 模式匹配
- 数据流追踪
3. **漏洞验证阶段**: 调度 Verification Agent 验证发现
- 漏洞确认
- PoC 生成
- 沙箱测试
4. **报告生成阶段**: 汇总所有发现生成最终报告
## 输出格式
当生成审计计划时返回 JSON:
```json
{
"phases": [
{"name": "阶段名", "description": "描述", "agent": "agent_type"}
],
"high_risk_areas": ["高风险目录/文件"],
"focus_vulnerabilities": ["重点漏洞类型"],
"priority_files": ["优先审计的文件"],
"estimated_steps": 数字
}
```
请基于项目信息制定合理的审计计划"""
class OrchestratorAgent(BaseAgent):
    """
    Orchestration agent (Plan-and-Execute pattern).

    1. First generates an audit plan (LLM, with a hard-coded fallback).
    2. Dispatches sub-agents phase by phase according to the plan.
    3. Collects per-phase results and produces a summary.
    """

    def __init__(
        self,
        llm_service,
        tools: Dict[str, Any],
        event_emitter=None,
        sub_agents: Optional[Dict[str, BaseAgent]] = None,
    ):
        config = AgentConfig(
            name="Orchestrator",
            agent_type=AgentType.ORCHESTRATOR,
            pattern=AgentPattern.PLAN_AND_EXECUTE,
            max_iterations=10,
            system_prompt=ORCHESTRATOR_SYSTEM_PROMPT,
        )
        super().__init__(config, llm_service, tools, event_emitter)
        # Sub-agents keyed by type name ("recon", "analysis", "verification")
        self.sub_agents = sub_agents or {}

    def register_sub_agent(self, name: str, agent: BaseAgent):
        """Register a sub-agent under the given type name."""
        self.sub_agents[name] = agent

    async def run(self, input_data: Dict[str, Any]) -> AgentResult:
        """
        Run the full audit orchestration.

        Args:
            input_data: {
                "project_info": project information dict,
                "config": audit configuration dict,
            }
        """
        import time
        start_time = time.time()
        project_info = input_data.get("project_info", {})
        config = input_data.get("config", {})
        try:
            await self.emit_thinking("开始制定审计计划...")
            # 1. Generate the audit plan
            plan = await self._create_audit_plan(project_info, config)
            if not plan:
                return AgentResult(
                    success=False,
                    error="无法生成审计计划",
                )
            await self.emit_event(
                "planning",
                f"审计计划已生成,共 {len(plan.phases)} 个阶段",
                metadata={"plan": plan.__dict__}
            )
            # 2. Execute each phase in order, honouring cancellation
            all_findings = []
            phase_results = {}
            for phase in plan.phases:
                if self.is_cancelled:
                    break
                phase_name = phase.get("name", "unknown")
                agent_type = phase.get("agent", "analysis")
                await self.emit_event(
                    "phase_start",
                    f"开始 {phase_name} 阶段",
                    phase=phase_name
                )
                # Dispatch the sub-agent responsible for this phase
                result = await self._execute_phase(
                    phase_name=phase_name,
                    agent_type=agent_type,
                    project_info=project_info,
                    config=config,
                    plan=plan,
                    previous_results=phase_results,
                )
                phase_results[phase_name] = result
                # Collect findings produced by the phase, if any
                if result.success and result.data:
                    if isinstance(result.data, dict):
                        findings = result.data.get("findings", [])
                        all_findings.extend(findings)
                await self.emit_event(
                    "phase_complete",
                    f"{phase_name} 阶段完成",
                    phase=phase_name
                )
            # 3. Aggregate the per-phase results into a summary
            await self.emit_thinking("汇总审计结果...")
            summary = await self._generate_summary(
                plan=plan,
                phase_results=phase_results,
                all_findings=all_findings,
            )
            duration_ms = int((time.time() - start_time) * 1000)
            return AgentResult(
                success=True,
                data={
                    "plan": plan.__dict__,
                    "findings": all_findings,
                    "summary": summary,
                    "phase_results": {k: v.to_dict() for k, v in phase_results.items()},
                },
                iterations=self._iteration,
                tool_calls=self._tool_calls,
                tokens_used=self._total_tokens,
                duration_ms=duration_ms,
            )
        except Exception as e:
            logger.error(f"Orchestrator failed: {e}", exc_info=True)
            return AgentResult(
                success=False,
                error=str(e),
            )

    async def _create_audit_plan(
        self,
        project_info: Dict[str, Any],
        config: Dict[str, Any],
    ) -> Optional[AuditPlan]:
        """Generate the audit plan via the LLM; fall back to defaults on failure."""
        # Build the planning prompt from project info and user configuration
        prompt = f"""基于以下项目信息,制定安全审计计划。
## 项目信息
- 名称: {project_info.get('name', 'unknown')}
- 语言: {project_info.get('languages', [])}
- 文件数量: {project_info.get('file_count', 0)}
- 目录结构: {project_info.get('structure', {})}
## 用户配置
- 目标漏洞: {config.get('target_vulnerabilities', [])}
- 验证级别: {config.get('verification_level', 'sandbox')}
- 排除模式: {config.get('exclude_patterns', [])}
请生成审计计划返回 JSON 格式"""
        try:
            # Ask the LLM for a plan
            messages = [
                {"role": "system", "content": self.config.system_prompt},
                {"role": "user", "content": prompt},
            ]
            response = await self.llm_service.chat_completion_raw(
                messages=messages,
                temperature=0.1,
                max_tokens=2000,
            )
            content = response.get("content", "")
            # Parse the first JSON object found in the response
            import json
            import re
            # Greedy match from the first "{" to the last "}"
            json_match = re.search(r'\{[\s\S]*\}', content)
            if json_match:
                plan_data = json.loads(json_match.group())
                return AuditPlan(
                    phases=plan_data.get("phases", self._default_phases()),
                    high_risk_areas=plan_data.get("high_risk_areas", []),
                    focus_vulnerabilities=plan_data.get("focus_vulnerabilities", []),
                    estimated_steps=plan_data.get("estimated_steps", 30),
                    priority_files=plan_data.get("priority_files", []),
                    metadata=plan_data,
                )
            else:
                # No JSON in the response: fall back to a default plan
                return AuditPlan(
                    phases=self._default_phases(),
                    high_risk_areas=["src/", "api/", "controllers/", "routes/"],
                    focus_vulnerabilities=["sql_injection", "xss", "command_injection"],
                    estimated_steps=30,
                    priority_files=[],
                    metadata={},
                )
        except Exception as e:
            # LLM or JSON failure: degrade to the default plan rather than abort
            logger.error(f"Failed to create audit plan: {e}")
            return AuditPlan(
                phases=self._default_phases(),
                high_risk_areas=[],
                focus_vulnerabilities=[],
                estimated_steps=30,
                priority_files=[],
                metadata={},
            )

    def _default_phases(self) -> List[Dict[str, Any]]:
        """Default four-phase audit used when the LLM provides no plan."""
        return [
            {
                "name": "recon",
                "description": "信息收集 - 分析项目结构和技术栈",
                "agent": "recon",
            },
            {
                "name": "static_analysis",
                "description": "静态分析 - 使用外部工具快速扫描",
                "agent": "analysis",
            },
            {
                "name": "deep_analysis",
                "description": "深度分析 - AI 驱动的代码审计",
                "agent": "analysis",
            },
            {
                "name": "verification",
                "description": "漏洞验证 - 确认发现的漏洞",
                "agent": "verification",
            },
        ]

    async def _execute_phase(
        self,
        phase_name: str,
        agent_type: str,
        project_info: Dict[str, Any],
        config: Dict[str, Any],
        plan: AuditPlan,
        previous_results: Dict[str, AgentResult],
    ) -> AgentResult:
        """Run one phase by delegating to the registered sub-agent."""
        agent = self.sub_agents.get(agent_type)
        if not agent:
            logger.warning(f"Agent not found: {agent_type}")
            return AgentResult(success=False, error=f"Agent {agent_type} not found")
        # Assemble the phase input, including all earlier phases' results
        phase_input = {
            "phase_name": phase_name,
            "project_info": project_info,
            "config": config,
            "plan": plan.__dict__,
            "previous_results": {k: v.to_dict() for k, v in previous_results.items()},
        }
        # Hand off to the sub-agent
        return await agent.run(phase_input)

    async def _generate_summary(
        self,
        plan: AuditPlan,
        phase_results: Dict[str, AgentResult],
        all_findings: List[Dict],
    ) -> Dict[str, Any]:
        """Aggregate findings into counts and a simple security score."""
        # Tally severities, vulnerability types and verified findings
        severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0}
        type_counts = {}
        verified_count = 0
        for finding in all_findings:
            sev = finding.get("severity", "low")
            severity_counts[sev] = severity_counts.get(sev, 0) + 1
            vtype = finding.get("vulnerability_type", "other")
            type_counts[vtype] = type_counts.get(vtype, 0) + 1
            if finding.get("is_verified"):
                verified_count += 1
        # Security score: 100 minus fixed per-severity deductions, floored at 0
        base_score = 100
        deductions = (
            severity_counts["critical"] * 20 +
            severity_counts["high"] * 10 +
            severity_counts["medium"] * 5 +
            severity_counts["low"] * 2
        )
        security_score = max(0, base_score - deductions)
        return {
            "total_findings": len(all_findings),
            "verified_count": verified_count,
            "severity_distribution": severity_counts,
            "vulnerability_types": type_counts,
            "security_score": security_score,
            "phases_completed": len(phase_results),
            "high_risk_areas": plan.high_risk_areas,
        }

View File

@ -0,0 +1,435 @@
"""
Recon Agent (信息收集层)
负责项目结构分析技术栈识别入口点识别
类型: ReAct
"""
import asyncio
import logging
import os
from typing import List, Dict, Any, Optional
from dataclasses import dataclass, field
from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern
logger = logging.getLogger(__name__)
# System prompt for the recon agent. Deliberately kept in Chinese: it is
# sent to the LLM verbatim and must not be altered for cosmetic reasons.
RECON_SYSTEM_PROMPT = """你是 DeepAudit 的信息收集 Agent负责在安全审计前收集项目信息。
## 你的职责
1. 分析项目结构和目录布局
2. 识别使用的技术栈和框架
3. 找出应用程序入口点
4. 分析依赖和第三方库
5. 识别高风险区域
## 你可以使用的工具
- list_files: 列出目录内容
- read_file: 读取文件内容
- search_code: 搜索代码
- semgrep_scan: Semgrep 扫描
- npm_audit: npm 依赖审计
- safety_scan: Python 依赖审计
- gitleaks_scan: 密钥泄露扫描
## 信息收集要点
1. **目录结构**: 了解项目布局识别源码配置测试目录
2. **技术栈**: 检测语言框架数据库等
3. **入口点**: API 路由控制器处理函数
4. **配置文件**: 环境变量数据库配置API 密钥
5. **依赖**: package.json, requirements.txt, go.mod
6. **安全相关**: 认证授权加密相关代码
## 输出格式
完成后返回 JSON:
```json
{
"project_structure": {...},
"tech_stack": {
"languages": [],
"frameworks": [],
"databases": []
},
"entry_points": [],
"high_risk_areas": [],
"dependencies": {...},
"initial_findings": []
}
```
请系统性地收集信息为后续分析做准备"""
class ReconAgent(BaseAgent):
    """
    Reconnaissance agent (ReAct pattern).

    Iteratively gathers project information — structure, tech stack,
    dependency vulnerabilities, leaked secrets, entry points and
    high-risk areas — before deeper analysis runs.
    """

    def __init__(
        self,
        llm_service,
        tools: Dict[str, Any],
        event_emitter=None,
    ):
        config = AgentConfig(
            name="Recon",
            agent_type=AgentType.RECON,
            pattern=AgentPattern.REACT,
            max_iterations=15,
            system_prompt=RECON_SYSTEM_PROMPT,
            tools=[
                "list_files", "read_file", "search_code",
                "semgrep_scan", "npm_audit", "safety_scan",
                "gitleaks_scan", "osv_scan",
            ],
        )
        super().__init__(config, llm_service, tools, event_emitter)

    async def run(self, input_data: Dict[str, Any]) -> AgentResult:
        """Run the reconnaissance pipeline and return the collected data."""
        import time
        start_time = time.time()
        project_info = input_data.get("project_info", {})
        config = input_data.get("config", {})
        try:
            await self.emit_thinking("开始信息收集...")
            # Accumulator for everything this phase discovers
            result_data = {
                "project_structure": {},
                "tech_stack": {
                    "languages": [],
                    "frameworks": [],
                    "databases": [],
                },
                "entry_points": [],
                "high_risk_areas": [],
                "dependencies": {},
                "initial_findings": [],
            }
            # 1. Analyze project structure
            await self.emit_thinking("分析项目结构...")
            structure = await self._analyze_structure()
            result_data["project_structure"] = structure
            # 2. Identify the tech stack
            await self.emit_thinking("识别技术栈...")
            tech_stack = await self._identify_tech_stack(structure)
            result_data["tech_stack"] = tech_stack
            # 3. Scan dependencies for known vulnerabilities
            await self.emit_thinking("扫描依赖漏洞...")
            deps_result = await self._scan_dependencies(tech_stack)
            result_data["dependencies"] = deps_result.get("dependencies", {})
            if deps_result.get("findings"):
                result_data["initial_findings"].extend(deps_result["findings"])
            # 4. Quick scan for leaked secrets
            await self.emit_thinking("扫描密钥泄露...")
            secrets_result = await self._scan_secrets()
            if secrets_result.get("findings"):
                result_data["initial_findings"].extend(secrets_result["findings"])
            # 5. Identify application entry points
            await self.emit_thinking("识别入口点...")
            entry_points = await self._identify_entry_points(tech_stack)
            result_data["entry_points"] = entry_points
            # 6. Derive high-risk areas from structure + entry points
            result_data["high_risk_areas"] = self._identify_high_risk_areas(
                structure, tech_stack, entry_points
            )
            duration_ms = int((time.time() - start_time) * 1000)
            await self.emit_event(
                "info",
                f"信息收集完成: 发现 {len(result_data['entry_points'])} 个入口点, "
                f"{len(result_data['high_risk_areas'])} 个高风险区域, "
                f"{len(result_data['initial_findings'])} 个初步发现"
            )
            return AgentResult(
                success=True,
                data=result_data,
                iterations=self._iteration,
                tool_calls=self._tool_calls,
                tokens_used=self._total_tokens,
                duration_ms=duration_ms,
            )
        except Exception as e:
            logger.error(f"Recon agent failed: {e}", exc_info=True)
            return AgentResult(success=False, error=str(e))

    async def _analyze_structure(self) -> Dict[str, Any]:
        """Analyze the project layout via the list_files tool."""
        structure = {
            "directories": [],
            "files_by_type": {},
            "config_files": [],
            "total_files": 0,
        }
        # List the repository root (bounded to 300 files)
        list_tool = self.tools.get("list_files")
        if not list_tool:
            return structure
        result = await list_tool.execute(directory=".", recursive=True, max_files=300)
        if result.success:
            structure["total_files"] = result.metadata.get("file_count", 0)
            # Well-known dependency/config file names to look for
            config_patterns = [
                "package.json", "requirements.txt", "go.mod", "Cargo.toml",
                "pom.xml", "build.gradle", ".env", "config.py", "settings.py",
                "docker-compose.yml", "Dockerfile",
            ]
            # Parse the textual file listing returned by the tool
            if isinstance(result.data, str):
                for line in result.data.split('\n'):
                    line = line.strip()
                    for pattern in config_patterns:
                        # NOTE(review): substring match — e.g. ".env" also
                        # matches "environments.py"; consider matching on the
                        # path basename to avoid false positives.
                        if pattern in line:
                            structure["config_files"].append(line)
        return structure

    async def _identify_tech_stack(self, structure: Dict) -> Dict[str, Any]:
        """Infer languages/frameworks/databases from config files' contents."""
        tech_stack = {
            "languages": [],
            "frameworks": [],
            "databases": [],
            "package_managers": [],
        }
        config_files = structure.get("config_files", [])
        # Infer languages and package managers from config file names
        for cfg in config_files:
            if "package.json" in cfg:
                tech_stack["languages"].append("JavaScript/TypeScript")
                tech_stack["package_managers"].append("npm")
            elif "requirements.txt" in cfg or "setup.py" in cfg:
                tech_stack["languages"].append("Python")
                tech_stack["package_managers"].append("pip")
            elif "go.mod" in cfg:
                tech_stack["languages"].append("Go")
            elif "Cargo.toml" in cfg:
                tech_stack["languages"].append("Rust")
            elif "pom.xml" in cfg or "build.gradle" in cfg:
                tech_stack["languages"].append("Java")
        # Read package.json to detect JS frameworks (substring checks)
        read_tool = self.tools.get("read_file")
        if read_tool and "package.json" in str(config_files):
            result = await read_tool.execute(file_path="package.json", max_lines=100)
            if result.success:
                content = result.data
                if "react" in content.lower():
                    tech_stack["frameworks"].append("React")
                if "vue" in content.lower():
                    tech_stack["frameworks"].append("Vue")
                if "express" in content.lower():
                    tech_stack["frameworks"].append("Express")
                if "fastify" in content.lower():
                    tech_stack["frameworks"].append("Fastify")
                if "next" in content.lower():
                    tech_stack["frameworks"].append("Next.js")
        # Read requirements.txt to detect Python frameworks / DB layers
        if read_tool and "requirements.txt" in str(config_files):
            result = await read_tool.execute(file_path="requirements.txt", max_lines=50)
            if result.success:
                content = result.data.lower()
                if "django" in content:
                    tech_stack["frameworks"].append("Django")
                if "flask" in content:
                    tech_stack["frameworks"].append("Flask")
                if "fastapi" in content:
                    tech_stack["frameworks"].append("FastAPI")
                if "sqlalchemy" in content:
                    tech_stack["databases"].append("SQLAlchemy")
                if "pymongo" in content:
                    tech_stack["databases"].append("MongoDB")
        # De-duplicate (set() does not preserve order)
        tech_stack["languages"] = list(set(tech_stack["languages"]))
        tech_stack["frameworks"] = list(set(tech_stack["frameworks"]))
        tech_stack["databases"] = list(set(tech_stack["databases"]))
        return tech_stack

    async def _scan_dependencies(self, tech_stack: Dict) -> Dict[str, Any]:
        """Run dependency scanners (npm audit / safety / OSV) when applicable."""
        result = {
            "dependencies": {},
            "findings": [],
        }
        # npm audit (JavaScript/TypeScript projects)
        if "npm" in tech_stack.get("package_managers", []):
            npm_tool = self.tools.get("npm_audit")
            if npm_tool:
                npm_result = await npm_tool.execute()
                if npm_result.success and npm_result.metadata.get("findings_count", 0) > 0:
                    result["dependencies"]["npm"] = npm_result.metadata
                    # Convert only critical/high severities into findings
                    for sev, count in npm_result.metadata.get("severity_counts", {}).items():
                        if count > 0 and sev in ["critical", "high"]:
                            result["findings"].append({
                                "vulnerability_type": "dependency_vulnerability",
                                "severity": sev,
                                "title": f"npm 依赖漏洞 ({count}{sev})",
                                "source": "npm_audit",
                            })
        # Safety (Python projects)
        if "pip" in tech_stack.get("package_managers", []):
            safety_tool = self.tools.get("safety_scan")
            if safety_tool:
                safety_result = await safety_tool.execute()
                if safety_result.success and safety_result.metadata.get("findings_count", 0) > 0:
                    result["dependencies"]["pip"] = safety_result.metadata
                    result["findings"].append({
                        "vulnerability_type": "dependency_vulnerability",
                        "severity": "high",
                        "title": f"Python 依赖漏洞",
                        "source": "safety",
                    })
        # OSV Scanner (language-agnostic; metadata only, no finding entry)
        osv_tool = self.tools.get("osv_scan")
        if osv_tool:
            osv_result = await osv_tool.execute()
            if osv_result.success and osv_result.metadata.get("findings_count", 0) > 0:
                result["dependencies"]["osv"] = osv_result.metadata
        return result

    async def _scan_secrets(self) -> Dict[str, Any]:
        """Scan for leaked secrets with gitleaks; each leak becomes a finding."""
        result = {"findings": []}
        gitleaks_tool = self.tools.get("gitleaks_scan")
        if gitleaks_tool:
            gl_result = await gitleaks_tool.execute()
            if gl_result.success and gl_result.metadata.get("findings_count", 0) > 0:
                for finding in gl_result.metadata.get("findings", []):
                    result["findings"].append({
                        "vulnerability_type": "hardcoded_secret",
                        "severity": "high",
                        "title": f"密钥泄露: {finding.get('rule', 'unknown')}",
                        "file_path": finding.get("file"),
                        "line_start": finding.get("line"),
                        "source": "gitleaks",
                    })
        return result

    async def _identify_entry_points(self, tech_stack: Dict) -> List[Dict[str, Any]]:
        """Locate routes/handlers by framework-specific code searches."""
        entry_points = []
        search_tool = self.tools.get("search_code")
        if not search_tool:
            return entry_points
        # Framework-specific search patterns, built from the detected stack
        search_patterns = []
        frameworks = tech_stack.get("frameworks", [])
        if "Express" in frameworks:
            search_patterns.extend([
                ("app.get(", "Express GET route"),
                ("app.post(", "Express POST route"),
                ("router.get(", "Express router GET"),
                ("router.post(", "Express router POST"),
            ])
        if "FastAPI" in frameworks:
            search_patterns.extend([
                ("@app.get(", "FastAPI GET endpoint"),
                ("@app.post(", "FastAPI POST endpoint"),
                ("@router.get(", "FastAPI router GET"),
                ("@router.post(", "FastAPI router POST"),
            ])
        if "Django" in frameworks:
            search_patterns.extend([
                ("def get(self", "Django GET view"),
                ("def post(self", "Django POST view"),
                ("path(", "Django URL pattern"),
            ])
        if "Flask" in frameworks:
            search_patterns.extend([
                ("@app.route(", "Flask route"),
                ("@blueprint.route(", "Flask blueprint route"),
            ])
        # Framework-agnostic fallback patterns
        search_patterns.extend([
            ("def handle", "Handler function"),
            ("async def handle", "Async handler"),
            ("class.*Controller", "Controller class"),
            ("class.*Handler", "Handler class"),
        ])
        # Only the first 10 patterns are searched, 5 matches kept each
        for pattern, description in search_patterns[:10]:
            result = await search_tool.execute(keyword=pattern, max_results=10)
            if result.success and result.metadata.get("matches", 0) > 0:
                for match in result.metadata.get("results", [])[:5]:
                    entry_points.append({
                        "type": description,
                        "file": match.get("file"),
                        "line": match.get("line"),
                        "pattern": pattern,
                    })
        return entry_points[:30]  # cap the total number of entry points

    def _identify_high_risk_areas(
        self,
        structure: Dict,
        tech_stack: Dict,
        entry_points: List[Dict],
    ) -> List[str]:
        """Combine conventional risky directories with entry-point directories.

        NOTE(review): ``structure`` and ``tech_stack`` are currently unused
        here; the result is the static list plus entry-point parents.
        """
        high_risk = set()
        # Conventionally risky directory names, added unconditionally
        risk_dirs = [
            "auth/", "authentication/", "login/",
            "api/", "routes/", "controllers/", "handlers/",
            "db/", "database/", "models/",
            "admin/", "management/",
            "upload/", "file/",
            "payment/", "billing/",
        ]
        for dir_name in risk_dirs:
            high_risk.add(dir_name)
        # Add the parent directory of every discovered entry point
        for ep in entry_points:
            file_path = ep.get("file", "")
            if "/" in file_path:
                dir_path = "/".join(file_path.split("/")[:-1]) + "/"
                high_risk.add(dir_path)
        # Cap at 20; set iteration order is unspecified
        return list(high_risk)[:20]

View File

@ -0,0 +1,392 @@
"""
Verification Agent (漏洞验证层)
负责漏洞确认PoC 生成沙箱测试
类型: ReAct
"""
import asyncio
import logging
from typing import List, Dict, Any, Optional
from datetime import datetime, timezone
from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern
logger = logging.getLogger(__name__)
# System prompt for the verification agent. Deliberately kept in Chinese:
# it is sent to the LLM verbatim and must not be altered for cosmetic reasons.
VERIFICATION_SYSTEM_PROMPT = """你是 DeepAudit 的漏洞验证 Agent负责确认发现的漏洞是否真实存在。
## 你的职责
1. 分析漏洞上下文判断是否为真正的安全问题
2. 构造 PoC概念验证代码
3. 在沙箱中执行测试
4. 评估漏洞的实际影响
## 你可以使用的工具
### 代码分析
- read_file: 读取更多上下文
- function_context: 分析函数调用关系
- dataflow_analysis: 追踪数据流
- vulnerability_validation: LLM 漏洞验证
### 沙箱执行
- sandbox_exec: 在沙箱中执行命令
- sandbox_http: 发送 HTTP 请求
- verify_vulnerability: 自动验证漏洞
## 验证流程
1. **上下文分析**: 获取更多代码上下文
2. **可利用性分析**: 判断漏洞是否可被利用
3. **PoC 构造**: 设计验证方案
4. **沙箱测试**: 在隔离环境中测试
5. **结果评估**: 确定漏洞是否真实存在
## 验证标准
- **确认 (confirmed)**: 漏洞真实存在且可利用
- **可能 (likely)**: 高度可能存在漏洞
- **不确定 (uncertain)**: 需要更多信息
- **误报 (false_positive)**: 确认是误报
## 输出格式
```json
{
"findings": [
{
"original_finding": {...},
"verdict": "confirmed/likely/uncertain/false_positive",
"confidence": 0.0-1.0,
"is_verified": true/false,
"verification_method": "描述验证方法",
"poc": {
"code": "PoC 代码",
"description": "描述",
"steps": ["步骤1", "步骤2"]
},
"impact": "影响分析",
"recommendation": "修复建议"
}
]
}
```
请谨慎验证减少误报同时不遗漏真正的漏洞"""
class VerificationAgent(BaseAgent):
"""
漏洞验证 Agent
使用 ReAct 模式验证发现的漏洞
"""
def __init__(
self,
llm_service,
tools: Dict[str, Any],
event_emitter=None,
):
config = AgentConfig(
name="Verification",
agent_type=AgentType.VERIFICATION,
pattern=AgentPattern.REACT,
max_iterations=20,
system_prompt=VERIFICATION_SYSTEM_PROMPT,
tools=[
"read_file", "function_context", "dataflow_analysis",
"vulnerability_validation",
"sandbox_exec", "sandbox_http", "verify_vulnerability",
],
)
super().__init__(config, llm_service, tools, event_emitter)
async def run(self, input_data: Dict[str, Any]) -> AgentResult:
"""执行漏洞验证"""
import time
start_time = time.time()
previous_results = input_data.get("previous_results", {})
config = input_data.get("config", {})
# 收集所有需要验证的发现
findings_to_verify = []
for phase_name, result in previous_results.items():
if isinstance(result, dict):
data = result.get("data", {})
else:
data = result.data if hasattr(result, 'data') else {}
if isinstance(data, dict):
phase_findings = data.get("findings", [])
for f in phase_findings:
if f.get("needs_verification", True):
findings_to_verify.append(f)
# 去重
findings_to_verify = self._deduplicate(findings_to_verify)
if not findings_to_verify:
await self.emit_event("info", "没有需要验证的发现")
return AgentResult(
success=True,
data={"findings": [], "verified_count": 0},
)
await self.emit_event(
"info",
f"开始验证 {len(findings_to_verify)} 个发现"
)
try:
verified_findings = []
verification_level = config.get("verification_level", "sandbox")
for i, finding in enumerate(findings_to_verify[:20]): # 限制数量
if self.is_cancelled:
break
await self.emit_thinking(
f"验证 [{i+1}/{min(len(findings_to_verify), 20)}]: {finding.get('title', 'unknown')}"
)
# 执行验证
verified = await self._verify_finding(finding, verification_level)
verified_findings.append(verified)
# 发射事件
if verified.get("is_verified"):
await self.emit_event(
"finding_verified",
f"✅ 已确认: {verified.get('title', '')}",
finding_id=verified.get("id"),
metadata={"severity": verified.get("severity")}
)
elif verified.get("verdict") == "false_positive":
await self.emit_event(
"finding_false_positive",
f"❌ 误报: {verified.get('title', '')}",
finding_id=verified.get("id"),
)
# 统计
confirmed_count = len([f for f in verified_findings if f.get("is_verified")])
likely_count = len([f for f in verified_findings if f.get("verdict") == "likely"])
false_positive_count = len([f for f in verified_findings if f.get("verdict") == "false_positive"])
duration_ms = int((time.time() - start_time) * 1000)
await self.emit_event(
"info",
f"验证完成: {confirmed_count} 确认, {likely_count} 可能, {false_positive_count} 误报"
)
return AgentResult(
success=True,
data={
"findings": verified_findings,
"verified_count": confirmed_count,
"likely_count": likely_count,
"false_positive_count": false_positive_count,
},
iterations=self._iteration,
tool_calls=self._tool_calls,
tokens_used=self._total_tokens,
duration_ms=duration_ms,
)
except Exception as e:
logger.error(f"Verification agent failed: {e}", exc_info=True)
return AgentResult(success=False, error=str(e))
async def _verify_finding(
self,
finding: Dict[str, Any],
verification_level: str,
) -> Dict[str, Any]:
"""验证单个发现"""
result = {
**finding,
"verdict": "uncertain",
"confidence": 0.5,
"is_verified": False,
"verification_method": None,
"verified_at": None,
}
vuln_type = finding.get("vulnerability_type", "")
file_path = finding.get("file_path", "")
line_start = finding.get("line_start", 0)
code_snippet = finding.get("code_snippet", "")
try:
# 1. 获取更多上下文
context = await self._get_context(file_path, line_start)
# 2. LLM 验证
validation_result = await self._llm_validation(
finding, context
)
result["verdict"] = validation_result.get("verdict", "uncertain")
result["confidence"] = validation_result.get("confidence", 0.5)
result["verification_method"] = "llm_analysis"
# 3. 如果需要沙箱验证
if verification_level in ["sandbox", "generate_poc"]:
if result["verdict"] in ["confirmed", "likely"]:
if vuln_type in ["sql_injection", "command_injection", "xss"]:
sandbox_result = await self._sandbox_verification(
finding, validation_result
)
if sandbox_result.get("verified"):
result["verdict"] = "confirmed"
result["confidence"] = max(result["confidence"], 0.9)
result["verification_method"] = "sandbox_test"
result["poc"] = sandbox_result.get("poc")
# 4. 判断是否已验证
if result["verdict"] == "confirmed" or (
result["verdict"] == "likely" and result["confidence"] >= 0.8
):
result["is_verified"] = True
result["verified_at"] = datetime.now(timezone.utc).isoformat()
# 5. 添加修复建议
if result["is_verified"]:
result["recommendation"] = self._get_recommendation(vuln_type)
except Exception as e:
logger.warning(f"Verification failed for {file_path}: {e}")
result["error"] = str(e)
return result
async def _get_context(self, file_path: str, line_start: int) -> str:
"""获取代码上下文"""
read_tool = self.tools.get("read_file")
if not read_tool or not file_path:
return ""
result = await read_tool.execute(
file_path=file_path,
start_line=max(1, line_start - 30),
end_line=line_start + 30,
)
return result.data if result.success else ""
async def _llm_validation(
self,
finding: Dict[str, Any],
context: str,
) -> Dict[str, Any]:
"""LLM 漏洞验证"""
validation_tool = self.tools.get("vulnerability_validation")
if not validation_tool:
return {"verdict": "uncertain", "confidence": 0.5}
code = finding.get("code_snippet", "") or context[:2000]
result = await validation_tool.execute(
code=code,
vulnerability_type=finding.get("vulnerability_type", "unknown"),
file_path=finding.get("file_path", ""),
line_number=finding.get("line_start"),
context=context[:1000] if context else None,
)
if result.success and result.metadata.get("validation"):
validation = result.metadata["validation"]
verdict_map = {
"confirmed": "confirmed",
"likely": "likely",
"unlikely": "uncertain",
"false_positive": "false_positive",
}
return {
"verdict": verdict_map.get(validation.get("verdict", ""), "uncertain"),
"confidence": validation.get("confidence", 0.5),
"explanation": validation.get("detailed_analysis", ""),
"exploitation_conditions": validation.get("exploitation_conditions", []),
"poc_idea": validation.get("poc_idea"),
}
return {"verdict": "uncertain", "confidence": 0.5}
async def _sandbox_verification(
    self,
    finding: Dict[str, Any],
    validation_result: Dict[str, Any],
) -> Dict[str, Any]:
    """Attempt a runtime (sandbox) confirmation of a finding.

    Only command_injection currently has a sandbox probe, and it is a
    placeholder: it checks that the sandbox can execute a benign command,
    NOT that the suspected code path is actually injectable. SQL injection
    and XSS verification need a live target URL and are unimplemented.

    Args:
        finding: The candidate finding (uses "vulnerability_type").
        validation_result: LLM validation output; its "poc_idea" could seed
            a real PoC in a future implementation.

    Returns:
        {"verified": bool, "poc": Optional[dict]}
    """
    result: Dict[str, Any] = {"verified": False, "poc": None}
    vuln_type = finding.get("vulnerability_type", "")
    sandbox_tool = self.tools.get("sandbox_exec")
    verify_tool = self.tools.get("verify_vulnerability")
    if vuln_type == "command_injection" and sandbox_tool:
        # Placeholder probe: run a harmless command inside the sandbox.
        # TODO: actually exercise the suspected injection point instead of
        # only checking sandbox availability.
        exec_result = await sandbox_tool.execute(
            command="python3 -c \"print('test')\"",
            timeout=10,
        )
        if exec_result.success:
            result["verified"] = True
            result["poc"] = {
                "description": "命令注入测试",
                "method": "sandbox_exec",
            }
    elif vuln_type in ["sql_injection", "xss"] and verify_tool:
        # Automated verification requires a real target URL; intentionally a no-op.
        pass
    return result
def _get_recommendation(self, vuln_type: str) -> str:
"""获取修复建议"""
recommendations = {
"sql_injection": "使用参数化查询或 ORM避免字符串拼接构造 SQL",
"xss": "对用户输入进行 HTML 转义,使用 CSP避免 innerHTML",
"command_injection": "避免使用 shell=True使用参数列表传递命令",
"path_traversal": "验证和规范化路径,使用白名单,避免直接使用用户输入",
"ssrf": "验证和限制目标 URL使用白名单禁止内网访问",
"deserialization": "避免反序列化不可信数据,使用 JSON 替代 pickle/yaml",
"hardcoded_secret": "使用环境变量或密钥管理服务存储敏感信息",
"weak_crypto": "使用强加密算法AES-256, SHA-256+),避免 MD5/SHA1",
}
return recommendations.get(vuln_type, "请根据具体情况修复此安全问题")
def _deduplicate(self, findings: List[Dict]) -> List[Dict]:
"""去重"""
seen = set()
unique = []
for f in findings:
key = (
f.get("file_path", ""),
f.get("line_start", 0),
f.get("vulnerability_type", ""),
)
if key not in seen:
seen.add(key)
unique.append(f)
return unique

View File

@ -0,0 +1,371 @@
"""
Agent 事件管理器
负责事件的创建存储和推送
"""
import asyncio
import json
import logging
from typing import Optional, Dict, Any, List, AsyncGenerator, Callable
from datetime import datetime, timezone
from dataclasses import dataclass
import uuid
logger = logging.getLogger(__name__)
@dataclass
class AgentEventData:
    """Payload for one agent event, as passed to AgentEventEmitter.emit().

    Only event_type is mandatory; every other field defaults to empty and
    is forwarded unchanged by to_dict().
    """
    event_type: str                               # e.g. "tool_call", "finding_new", "progress"
    phase: Optional[str] = None                   # audit phase; emitter fills in its current phase when unset
    message: Optional[str] = None                 # human-readable description
    tool_name: Optional[str] = None               # set for tool_call / tool_result events
    tool_input: Optional[Dict[str, Any]] = None   # arguments the tool was invoked with
    tool_output: Optional[Dict[str, Any]] = None  # (possibly truncated) tool result
    tool_duration_ms: Optional[int] = None        # tool wall-clock duration in milliseconds
    finding_id: Optional[str] = None              # set for finding_new / finding_verified events
    tokens_used: int = 0                          # LLM tokens consumed by this step
    metadata: Optional[Dict[str, Any]] = None     # free-form extra data

    def to_dict(self) -> Dict[str, Any]:
        # Flat dict mirroring the field names; unpacked into
        # EventManager.add_event(**payload) by the emitter.
        return {
            "event_type": self.event_type,
            "phase": self.phase,
            "message": self.message,
            "tool_name": self.tool_name,
            "tool_input": self.tool_input,
            "tool_output": self.tool_output,
            "tool_duration_ms": self.tool_duration_ms,
            "finding_id": self.finding_id,
            "tokens_used": self.tokens_used,
            "metadata": self.metadata,
        }
class AgentEventEmitter:
    """
    Per-task event emitter.

    Thin convenience layer over EventManager.add_event that stamps each
    event with a monotonically increasing sequence number and the phase
    that is currently active.
    """

    def __init__(self, task_id: str, event_manager: 'EventManager'):
        self.task_id = task_id
        self.event_manager = event_manager
        self._sequence = 0
        self._current_phase = None

    async def emit(self, event_data: AgentEventData):
        """Forward one event to the manager, filling sequence and phase."""
        self._sequence += 1
        event_data.phase = event_data.phase or self._current_phase
        payload = event_data.to_dict()
        await self.event_manager.add_event(
            task_id=self.task_id,
            sequence=self._sequence,
            **payload,
        )

    async def emit_phase_start(self, phase: str, message: Optional[str] = None):
        """Announce a phase start and remember it as the active phase."""
        self._current_phase = phase
        text = message or f"开始 {phase} 阶段"
        await self.emit(AgentEventData(
            event_type="phase_start",
            phase=phase,
            message=text,
        ))

    async def emit_phase_complete(self, phase: str, message: Optional[str] = None):
        """Announce that a phase finished."""
        text = message or f"{phase} 阶段完成"
        await self.emit(AgentEventData(
            event_type="phase_complete",
            phase=phase,
            message=text,
        ))

    async def emit_thinking(self, message: str, metadata: Optional[Dict] = None):
        """Emit an agent-reasoning event."""
        await self.emit(AgentEventData(
            event_type="thinking",
            message=message,
            metadata=metadata,
        ))

    async def emit_tool_call(
        self,
        tool_name: str,
        tool_input: Dict[str, Any],
        message: Optional[str] = None,
    ):
        """Emit a tool-invocation event."""
        text = message or f"调用工具: {tool_name}"
        await self.emit(AgentEventData(
            event_type="tool_call",
            tool_name=tool_name,
            tool_input=tool_input,
            message=text,
        ))

    async def emit_tool_result(
        self,
        tool_name: str,
        tool_output: Any,
        duration_ms: int,
        message: Optional[str] = None,
    ):
        """Emit a tool-result event, normalising the output to a small dict."""
        if hasattr(tool_output, 'to_dict'):
            output_data = tool_output.to_dict()
        else:
            # Truncate long text so events stay lightweight and serialisable.
            text = tool_output if isinstance(tool_output, str) else str(tool_output)
            output_data = {"result": text[:2000]}
        await self.emit(AgentEventData(
            event_type="tool_result",
            tool_name=tool_name,
            tool_output=output_data,
            tool_duration_ms=duration_ms,
            message=message or f"工具 {tool_name} 执行完成 ({duration_ms}ms)",
        ))

    async def emit_finding(
        self,
        finding_id: str,
        title: str,
        severity: str,
        vulnerability_type: str,
        is_verified: bool = False,
    ):
        """Emit a new-finding or verified-finding event."""
        if is_verified:
            event_type = "finding_verified"
            prefix = '✅ 已验证'
        else:
            event_type = "finding_new"
            prefix = '🔍 新发现'
        await self.emit(AgentEventData(
            event_type=event_type,
            finding_id=finding_id,
            message=f"{prefix}: [{severity.upper()}] {title}",
            metadata={
                "title": title,
                "severity": severity,
                "vulnerability_type": vulnerability_type,
                "is_verified": is_verified,
            },
        ))

    async def emit_info(self, message: str, metadata: Optional[Dict] = None):
        """Emit an informational event."""
        await self.emit(AgentEventData(
            event_type="info",
            message=message,
            metadata=metadata,
        ))

    async def emit_warning(self, message: str, metadata: Optional[Dict] = None):
        """Emit a warning event."""
        await self.emit(AgentEventData(
            event_type="warning",
            message=message,
            metadata=metadata,
        ))

    async def emit_error(self, message: str, metadata: Optional[Dict] = None):
        """Emit an error event."""
        await self.emit(AgentEventData(
            event_type="error",
            message=message,
            metadata=metadata,
        ))

    async def emit_progress(
        self,
        current: int,
        total: int,
        message: Optional[str] = None,
    ):
        """Emit a progress event with a computed percentage."""
        percentage = (current / total * 100) if total > 0 else 0
        text = message or f"进度: {current}/{total} ({percentage:.1f}%)"
        await self.emit(AgentEventData(
            event_type="progress",
            message=text,
            metadata={
                "current": current,
                "total": total,
                "percentage": percentage,
            },
        ))
class EventManager:
    """
    Event manager.

    Stores agent events — optionally persisting them through an async DB
    session factory — and fans each event out to per-task asyncio queues
    (for SSE streaming) and registered observer callbacks.
    """

    def __init__(self, db_session_factory=None):
        # Async session factory; when None, events are never persisted to the DB.
        self.db_session_factory = db_session_factory
        # task_id -> queue of event dicts consumed by stream_events().
        self._event_queues: Dict[str, asyncio.Queue] = {}
        # task_id -> observer callbacks (plain callables or coroutine functions).
        self._event_callbacks: Dict[str, List[Callable]] = {}

    async def add_event(
        self,
        task_id: str,
        event_type: str,
        sequence: int = 0,
        phase: Optional[str] = None,
        message: Optional[str] = None,
        tool_name: Optional[str] = None,
        tool_input: Optional[Dict] = None,
        tool_output: Optional[Dict] = None,
        tool_duration_ms: Optional[int] = None,
        finding_id: Optional[str] = None,
        tokens_used: int = 0,
        metadata: Optional[Dict] = None,
    ):
        """Record one event: persist it (best effort), push it to the task's
        live queue if one exists, then invoke any registered callbacks.

        Returns:
            The generated event id (UUID4 string).
        """
        event_id = str(uuid.uuid4())
        timestamp = datetime.now(timezone.utc)
        event_data = {
            "id": event_id,
            "task_id": task_id,
            "event_type": event_type,
            "sequence": sequence,
            "phase": phase,
            "message": message,
            "tool_name": tool_name,
            "tool_input": tool_input,
            "tool_output": tool_output,
            "tool_duration_ms": tool_duration_ms,
            "finding_id": finding_id,
            "tokens_used": tokens_used,
            "metadata": metadata,
            "timestamp": timestamp.isoformat(),
        }
        # Persist to the database; failures are logged, never raised.
        if self.db_session_factory:
            try:
                await self._save_event_to_db(event_data)
            except Exception as e:
                logger.error(f"Failed to save event to database: {e}")
        # Push to the live stream queue, if a consumer is attached.
        if task_id in self._event_queues:
            await self._event_queues[task_id].put(event_data)
        # Notify observer callbacks (sync or async); each error is isolated.
        if task_id in self._event_callbacks:
            for callback in self._event_callbacks[task_id]:
                try:
                    if asyncio.iscoroutinefunction(callback):
                        await callback(event_data)
                    else:
                        callback(event_data)
                except Exception as e:
                    logger.error(f"Event callback error: {e}")
        return event_id

    async def _save_event_to_db(self, event_data: Dict):
        """Persist one event dict as an AgentEvent row (own short-lived session)."""
        from app.models.agent_task import AgentEvent
        async with self.db_session_factory() as db:
            event = AgentEvent(
                id=event_data["id"],
                task_id=event_data["task_id"],
                event_type=event_data["event_type"],
                sequence=event_data["sequence"],
                phase=event_data["phase"],
                message=event_data["message"],
                tool_name=event_data["tool_name"],
                tool_input=event_data["tool_input"],
                tool_output=event_data["tool_output"],
                tool_duration_ms=event_data["tool_duration_ms"],
                finding_id=event_data["finding_id"],
                tokens_used=event_data["tokens_used"],
                event_metadata=event_data["metadata"],
            )
            db.add(event)
            await db.commit()

    def create_queue(self, task_id: str) -> asyncio.Queue:
        """Return the task's event queue, creating it on first use."""
        if task_id not in self._event_queues:
            self._event_queues[task_id] = asyncio.Queue()
        return self._event_queues[task_id]

    def remove_queue(self, task_id: str):
        """Drop the task's event queue (pending items are discarded)."""
        if task_id in self._event_queues:
            del self._event_queues[task_id]

    def add_callback(self, task_id: str, callback: Callable):
        """Register an observer callback for a task's events."""
        if task_id not in self._event_callbacks:
            self._event_callbacks[task_id] = []
        self._event_callbacks[task_id].append(callback)

    def remove_callback(self, task_id: str, callback: Callable):
        """Unregister a previously added callback.

        NOTE(review): list.remove raises ValueError if the callback was
        never registered — presumably callers always pair add/remove; confirm.
        """
        if task_id in self._event_callbacks:
            self._event_callbacks[task_id].remove(callback)

    async def get_events(
        self,
        task_id: str,
        after_sequence: int = 0,
        limit: int = 100,
    ) -> List[Dict]:
        """Fetch up to `limit` stored events with sequence > after_sequence,
        ordered by sequence. Returns [] when no DB factory is configured."""
        if not self.db_session_factory:
            return []
        from sqlalchemy.future import select
        from app.models.agent_task import AgentEvent
        async with self.db_session_factory() as db:
            result = await db.execute(
                select(AgentEvent)
                .where(AgentEvent.task_id == task_id)
                .where(AgentEvent.sequence > after_sequence)
                .order_by(AgentEvent.sequence)
                .limit(limit)
            )
            events = result.scalars().all()
            return [event.to_sse_dict() for event in events]

    async def stream_events(
        self,
        task_id: str,
        after_sequence: int = 0,
    ) -> AsyncGenerator[Dict, None]:
        """Yield historical events, then live events from the task queue.

        Emits a heartbeat event after 30s of inactivity. The stream ends on
        a terminal live event (task_complete / task_error / task_cancel);
        note that terminal events replayed from history do NOT end the stream.
        """
        queue = self.create_queue(task_id)
        # Replay persisted history first.
        history = await self.get_events(task_id, after_sequence)
        for event in history:
            yield event
        # Then push live events as they arrive.
        try:
            while True:
                try:
                    event = await asyncio.wait_for(queue.get(), timeout=30)
                    yield event
                    # Terminal event closes the stream.
                    if event.get("event_type") in ["task_complete", "task_error", "task_cancel"]:
                        break
                except asyncio.TimeoutError:
                    # Keep the connection alive for SSE consumers.
                    yield {"event_type": "heartbeat", "timestamp": datetime.now(timezone.utc).isoformat()}
        finally:
            self.remove_queue(task_id)

    def create_emitter(self, task_id: str) -> AgentEventEmitter:
        """Build an AgentEventEmitter bound to this manager and task."""
        return AgentEventEmitter(task_id, self)

View File

@ -0,0 +1,28 @@
"""
LangGraph 工作流模块
使用状态图构建混合 Agent 审计流程
"""
from .audit_graph import AuditState, create_audit_graph, create_audit_graph_with_human
from .nodes import ReconNode, AnalysisNode, VerificationNode, ReportNode, HumanReviewNode
from .runner import AgentRunner, run_agent_task, LLMService
__all__ = [
# 状态和图
"AuditState",
"create_audit_graph",
"create_audit_graph_with_human",
# 节点
"ReconNode",
"AnalysisNode",
"VerificationNode",
"ReportNode",
"HumanReviewNode",
# Runner
"AgentRunner",
"run_agent_task",
"LLMService",
]

View File

@ -0,0 +1,455 @@
"""
DeepAudit 审计工作流图
使用 LangGraph 构建状态机式的 Agent 协作流程
"""
from typing import TypedDict, Annotated, List, Dict, Any, Optional, Literal
from datetime import datetime
import operator
import logging
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import ToolNode
logger = logging.getLogger(__name__)
# ============ 状态定义 ============
class Finding(TypedDict):
    """One vulnerability finding carried through the workflow state."""
    id: str
    vulnerability_type: str        # e.g. "sql_injection", "xss", "command_injection"
    severity: str                  # "critical" | "high" | "medium" | "low" — TODO confirm full set
    title: str
    description: str
    file_path: Optional[str]       # None when the finding is not tied to a file
    line_start: Optional[int]
    code_snippet: Optional[str]
    is_verified: bool              # set True by the verification phase
    confidence: float              # presumably 0.0-1.0; verify against producers
    source: str                    # which agent/tool produced this finding
class AuditState(TypedDict):
    """
    Audit state passed between, and updated by, every workflow node.

    Fields annotated with operator.add are merged additively across node
    returns (LangGraph reducer semantics), so nodes append to them rather
    than overwrite.
    """
    # --- Inputs ---
    project_root: str
    project_info: Dict[str, Any]
    config: Dict[str, Any]
    task_id: str
    # --- Recon phase output ---
    tech_stack: Dict[str, Any]
    entry_points: List[Dict[str, Any]]
    high_risk_areas: List[str]
    dependencies: Dict[str, Any]
    # --- Analysis phase output ---
    findings: Annotated[List[Finding], operator.add]  # accumulated across analysis iterations
    # --- Verification phase output ---
    verified_findings: List[Finding]
    false_positives: List[str]     # ids of findings judged to be false positives
    # --- Control flow ---
    current_phase: str
    iteration: int                 # analysis iteration counter
    max_iterations: int
    should_continue_analysis: bool
    # --- Messages and events ---
    messages: Annotated[List[Dict], operator.add]
    events: Annotated[List[Dict], operator.add]
    # --- Final outputs ---
    summary: Optional[Dict[str, Any]]
    security_score: Optional[int]  # 0-100, computed by the report node
    error: Optional[str]
# ============ 路由函数 ============
def route_after_recon(state: AuditState) -> Literal["analysis", "end"]:
    """Route after recon: only proceed when there is something to analyse."""
    has_targets = state.get("entry_points") or state.get("high_risk_areas")
    return "analysis" if has_targets else "end"
def route_after_analysis(state: AuditState) -> Literal["verification", "analysis", "report"]:
    """Decide the next hop after one analysis iteration."""
    if not state.get("findings", []):
        # Nothing found — skip verification and go straight to the report.
        return "report"
    wants_more = state.get("should_continue_analysis", False)
    under_budget = state.get("iteration", 0) < state.get("max_iterations", 3)
    if wants_more and under_budget:
        return "analysis"
    # Findings exist and no further analysis requested: verify them.
    return "verification"
def route_after_verification(state: AuditState) -> Literal["analysis", "report"]:
    """After verification: re-analyse when false positives outnumber confirmations."""
    fp_count = len(state.get("false_positives", []))
    confirmed_count = len(state.get("verified_findings", []))
    under_budget = state.get("iteration", 0) < state.get("max_iterations", 3)
    if fp_count > confirmed_count and under_budget:
        return "analysis"
    return "report"
# ============ 创建审计图 ============
def create_audit_graph(
    recon_node,
    analysis_node,
    verification_node,
    report_node,
    checkpointer: Optional[MemorySaver] = None,
) -> StateGraph:
    """
    Build and compile the audit workflow graph.

    Flow:
        START -> recon -> analysis (may loop) -> verification (may backtrack
        to analysis) -> report -> END

    Args:
        recon_node: Reconnaissance node callable.
        analysis_node: Vulnerability-analysis node callable.
        verification_node: Vulnerability-verification node callable.
        report_node: Report-generation node callable.
        checkpointer: Optional checkpoint store for state persistence.

    Returns:
        The compiled graph.
    """
    workflow = StateGraph(AuditState)
    # Register the four pipeline stages.
    for name, node in (
        ("recon", recon_node),
        ("analysis", analysis_node),
        ("verification", verification_node),
        ("report", report_node),
    ):
        workflow.add_node(name, node)
    workflow.set_entry_point("recon")
    # Conditional transitions mirror the route_after_* policies above.
    workflow.add_conditional_edges(
        "recon",
        route_after_recon,
        {"analysis": "analysis", "end": END},
    )
    workflow.add_conditional_edges(
        "analysis",
        route_after_analysis,
        {
            "verification": "verification",
            "analysis": "analysis",  # self-loop for further iterations
            "report": "report",
        },
    )
    workflow.add_conditional_edges(
        "verification",
        route_after_verification,
        {
            "analysis": "analysis",  # backtrack on high false-positive rate
            "report": "report",
        },
    )
    workflow.add_edge("report", END)
    if checkpointer:
        return workflow.compile(checkpointer=checkpointer)
    return workflow.compile()
# ============ 带人机协作的审计图 ============
def create_audit_graph_with_human(
    recon_node,
    analysis_node,
    verification_node,
    report_node,
    human_review_node,
    checkpointer: Optional[MemorySaver] = None,
) -> StateGraph:
    """
    Build the audit workflow with a human-in-the-loop review step.

    Flow:
        START -> recon -> analysis (may loop) -> verification
        -> human_review -> report -> END

    When a checkpointer is supplied, the graph is compiled with
    interrupt_before=["human_review"] so execution pauses for feedback.
    """
    workflow = StateGraph(AuditState)
    for name, node in (
        ("recon", recon_node),
        ("analysis", analysis_node),
        ("verification", verification_node),
        ("human_review", human_review_node),
        ("report", report_node),
    ):
        workflow.add_node(name, node)
    workflow.set_entry_point("recon")
    workflow.add_conditional_edges(
        "recon",
        route_after_recon,
        {"analysis": "analysis", "end": END}
    )
    workflow.add_conditional_edges(
        "analysis",
        route_after_analysis,
        {
            "verification": "verification",
            "analysis": "analysis",
            "report": "report",
        }
    )
    # Every verification pass funnels into human review.
    workflow.add_edge("verification", "human_review")

    def route_after_human(state: AuditState) -> Literal["analysis", "report"]:
        # The reviewer may request another analysis pass.
        return "analysis" if state.get("should_continue_analysis") else "report"

    workflow.add_conditional_edges(
        "human_review",
        route_after_human,
        {"analysis": "analysis", "report": "report"}
    )
    workflow.add_edge("report", END)
    if checkpointer:
        return workflow.compile(checkpointer=checkpointer, interrupt_before=["human_review"])
    return workflow.compile()
# ============ 执行器 ============
class AuditGraphRunner:
    """
    Audit graph executor.

    Wraps execution of a compiled LangGraph workflow, optionally forwarding
    per-node progress to an event emitter.
    """

    def __init__(
        self,
        graph: StateGraph,
        event_emitter=None,
    ):
        # graph: a compiled audit graph (see create_audit_graph*).
        self.graph = graph
        # Optional emitter with async emit_info(message); progress is silent when None.
        self.event_emitter = event_emitter

    async def run(
        self,
        project_root: str,
        project_info: Dict[str, Any],
        config: Dict[str, Any],
        task_id: str,
    ) -> Dict[str, Any]:
        """
        Execute the audit workflow from a fresh initial state.

        Args:
            project_root: Filesystem root of the project under audit.
            project_info: Pre-collected project metadata.
            config: Audit configuration (reads "max_iterations", default 3).
            task_id: Task id, also used as the LangGraph thread_id.

        Returns:
            The final workflow state (AuditState fields as a dict).

        Raises:
            Any exception raised during graph execution (logged, then re-raised).
        """
        # Fresh initial state; every phase output starts empty.
        initial_state: AuditState = {
            "project_root": project_root,
            "project_info": project_info,
            "config": config,
            "task_id": task_id,
            "tech_stack": {},
            "entry_points": [],
            "high_risk_areas": [],
            "dependencies": {},
            "findings": [],
            "verified_findings": [],
            "false_positives": [],
            "current_phase": "start",
            "iteration": 0,
            "max_iterations": config.get("max_iterations", 3),
            "should_continue_analysis": False,
            "messages": [],
            "events": [],
            "summary": None,
            "security_score": None,
            "error": None,
        }
        # thread_id ties checkpointer state to this task.
        run_config = {
            "configurable": {
                "thread_id": task_id,
            }
        }
        try:
            # Stream node-by-node so progress can be surfaced as events.
            async for event in self.graph.astream(initial_state, config=run_config):
                if self.event_emitter:
                    for node_name, node_state in event.items():
                        await self.event_emitter.emit_info(
                            f"节点 {node_name} 完成"
                        )
                        # Surface newly produced findings from the analysis node.
                        if node_name == "analysis" and node_state.get("findings"):
                            new_findings = node_state["findings"]
                            await self.event_emitter.emit_info(
                                f"发现 {len(new_findings)} 个潜在漏洞"
                            )
            # After the stream ends, read the final state from the graph.
            final_state = self.graph.get_state(run_config)
            return final_state.values
        except Exception as e:
            logger.error(f"Graph execution failed: {e}", exc_info=True)
            raise

    async def run_with_human_review(
        self,
        initial_state: AuditState,
        human_feedback_callback,
    ) -> Dict[str, Any]:
        """
        Execute with a human-in-the-loop pause.

        Args:
            initial_state: Fully populated initial AuditState.
            human_feedback_callback: Async callable receiving the paused
                state values and returning a decision dict; its
                "continue_analysis" flag (default False) is folded back in.

        Returns:
            The final workflow state values.
        """
        run_config = {
            "configurable": {
                "thread_id": initial_state["task_id"],
            }
        }
        # Run until the graph interrupts (compiled with interrupt_before=["human_review"]).
        async for event in self.graph.astream(initial_state, config=run_config):
            pass
        # Inspect where the graph stopped.
        current_state = self.graph.get_state(run_config)
        # Paused at the human-review node?
        if current_state.next == ("human_review",):
            # Ask the human for a decision on the current findings.
            human_decision = await human_feedback_callback(current_state.values)
            # Fold the decision back into the state and resume execution.
            updated_state = {
                **current_state.values,
                "should_continue_analysis": human_decision.get("continue_analysis", False),
            }
            async for event in self.graph.astream(updated_state, config=run_config):
                pass
        # Return whatever state the graph ended with.
        return self.graph.get_state(run_config).values

View File

@ -0,0 +1,360 @@
"""
LangGraph 节点实现
每个节点封装一个 Agent 的执行逻辑
"""
from typing import Dict, Any, List, Optional
import logging
logger = logging.getLogger(__name__)
# 延迟导入避免循环依赖
def get_audit_state_type():
    """Resolve AuditState lazily to avoid a circular import at module load."""
    from . import audit_graph
    return audit_graph.AuditState
class BaseNode:
    """Common base for workflow nodes: holds the agent and event emitter."""

    def __init__(self, agent=None, event_emitter=None):
        self.agent = agent
        self.event_emitter = event_emitter

    async def emit_event(self, event_type: str, message: str, **kwargs):
        """Best-effort progress emission; no-op without an emitter.

        NOTE(review): event_type and kwargs are currently ignored — only the
        message is forwarded via emit_info.
        """
        if not self.event_emitter:
            return
        try:
            await self.event_emitter.emit_info(message)
        except Exception as e:
            logger.warning(f"Failed to emit event: {e}")
class ReconNode(BaseNode):
    """
    Reconnaissance node.

    Input:  project_root, project_info, config
    Output: tech_stack, entry_points, high_risk_areas, dependencies
    """

    async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """Run the recon agent and fold its output into the graph state."""
        await self.emit_event("phase_start", "🔍 开始信息收集阶段")
        try:
            result = await self.agent.run({
                "project_info": state["project_info"],
                "config": state["config"],
            })
            if not (result.success and result.data):
                return {
                    "error": result.error or "Recon failed",
                    "current_phase": "error",
                }
            data = result.data
            entry_points = data.get("entry_points", [])
            high_risk_areas = data.get("high_risk_areas", [])
            await self.emit_event(
                "phase_complete",
                f"✅ 信息收集完成: 发现 {len(entry_points)} 个入口点"
            )
            return {
                "tech_stack": data.get("tech_stack", {}),
                "entry_points": entry_points,
                "high_risk_areas": high_risk_areas,
                "dependencies": data.get("dependencies", {}),
                "current_phase": "recon_complete",
                "findings": data.get("initial_findings", []),  # early findings, if any
                "events": [{
                    "type": "recon_complete",
                    "data": {
                        "entry_points_count": len(entry_points),
                        "high_risk_areas_count": len(high_risk_areas),
                    }
                }],
            }
        except Exception as e:
            logger.error(f"Recon node failed: {e}", exc_info=True)
            return {
                "error": str(e),
                "current_phase": "error",
            }
class AnalysisNode(BaseNode):
    """
    Vulnerability-analysis node.

    Input:  tech_stack, entry_points, high_risk_areas, previous findings
    Output: findings (additively merged), should_continue_analysis
    """

    async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """Run one analysis iteration and decide whether another is warranted."""
        iteration = state.get("iteration", 0) + 1
        await self.emit_event(
            "phase_start",
            f"🔬 开始漏洞分析阶段 (迭代 {iteration})"
        )
        try:
            # Feed the recon output forward as "previous results".
            recon_data = {
                "tech_stack": state.get("tech_stack", {}),
                "entry_points": state.get("entry_points", []),
                "high_risk_areas": state.get("high_risk_areas", []),
            }
            analysis_input = {
                "phase_name": "analysis",
                "project_info": state["project_info"],
                "config": state["config"],
                "plan": {
                    "high_risk_areas": state.get("high_risk_areas", []),
                },
                "previous_results": {
                    "recon": {
                        "data": recon_data,
                    }
                },
            }
            result = await self.agent.run(analysis_input)
            if not (result.success and result.data):
                return {
                    "iteration": iteration,
                    "should_continue_analysis": False,
                    "current_phase": "analysis_complete",
                }
            new_findings = result.data.get("findings", [])
            # A productive round (>= 5 findings) suggests more may remain.
            should_continue = (
                len(new_findings) >= 5
                and iteration < state.get("max_iterations", 3)
            )
            await self.emit_event(
                "phase_complete",
                f"✅ 分析迭代 {iteration} 完成: 发现 {len(new_findings)} 个潜在漏洞"
            )
            return {
                "findings": new_findings,  # additively merged by the reducer
                "iteration": iteration,
                "should_continue_analysis": should_continue,
                "current_phase": "analysis_complete",
                "events": [{
                    "type": "analysis_iteration",
                    "data": {
                        "iteration": iteration,
                        "findings_count": len(new_findings),
                    }
                }],
            }
        except Exception as e:
            logger.error(f"Analysis node failed: {e}", exc_info=True)
            return {
                "error": str(e),
                "should_continue_analysis": False,
                "current_phase": "error",
            }
class VerificationNode(BaseNode):
    """
    Vulnerability-verification node.

    Input:  findings
    Output: verified_findings, false_positives
    """

    async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """Run the verification agent over accumulated findings."""
        findings = state.get("findings", [])
        if not findings:
            # Nothing to verify — short-circuit with empty results.
            return {
                "verified_findings": [],
                "false_positives": [],
                "current_phase": "verification_complete",
            }
        await self.emit_event(
            "phase_start",
            f"🔐 开始漏洞验证阶段 ({len(findings)} 个待验证)"
        )
        try:
            result = await self.agent.run({
                "previous_results": {
                    "analysis": {
                        "data": {
                            "findings": findings,
                        }
                    }
                },
                "config": state["config"],
            })
            if not (result.success and result.data):
                return {
                    "verified_findings": [],
                    "false_positives": [],
                    "current_phase": "verification_complete",
                }
            checked = result.data.get("findings", [])
            verified = [f for f in checked if f.get("is_verified")]
            false_pos = [f["id"] for f in checked
                         if f.get("verdict") == "false_positive"]
            await self.emit_event(
                "phase_complete",
                f"✅ 验证完成: {len(verified)} 已确认, {len(false_pos)} 误报"
            )
            return {
                "verified_findings": verified,
                "false_positives": false_pos,
                "current_phase": "verification_complete",
                "events": [{
                    "type": "verification_complete",
                    "data": {
                        "verified_count": len(verified),
                        "false_positive_count": len(false_pos),
                    }
                }],
            }
        except Exception as e:
            logger.error(f"Verification node failed: {e}", exc_info=True)
            return {
                "error": str(e),
                "current_phase": "error",
            }
class ReportNode(BaseNode):
    """
    Report-generation node.

    Input:  the full accumulated state
    Output: summary, security_score
    """

    async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """Aggregate statistics, compute the security score, and summarise."""
        await self.emit_event("phase_start", "📊 生成审计报告")
        try:
            findings = state.get("findings", [])
            verified = state.get("verified_findings", [])
            false_positives = state.get("false_positives", [])
            # Tally findings by severity and by vulnerability type.
            severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0}
            type_counts: Dict[str, int] = {}
            for finding in findings:
                sev = finding.get("severity", "medium")
                severity_counts[sev] = severity_counts.get(sev, 0) + 1
                vtype = finding.get("vulnerability_type", "other")
                type_counts[vtype] = type_counts.get(vtype, 0) + 1
            # Security score: 100 minus weighted deductions, floored at 0.
            # Only the four standard severities contribute deductions.
            weights = {"critical": 25, "high": 15, "medium": 8, "low": 3}
            deductions = sum(weights[s] * severity_counts[s] for s in weights)
            security_score = max(0, 100 - deductions)
            summary = {
                "total_findings": len(findings),
                "verified_count": len(verified),
                "false_positive_count": len(false_positives),
                "severity_distribution": severity_counts,
                "vulnerability_types": type_counts,
                "tech_stack": state.get("tech_stack", {}),
                "entry_points_analyzed": len(state.get("entry_points", [])),
                "high_risk_areas": state.get("high_risk_areas", []),
                "iterations": state.get("iteration", 1),
            }
            await self.emit_event(
                "phase_complete",
                f"✅ 报告生成完成: 安全评分 {security_score}/100"
            )
            return {
                "summary": summary,
                "security_score": security_score,
                "current_phase": "complete",
                "events": [{
                    "type": "audit_complete",
                    "data": {
                        "security_score": security_score,
                        "total_findings": len(findings),
                        "verified_count": len(verified),
                    }
                }],
            }
        except Exception as e:
            logger.error(f"Report node failed: {e}", exc_info=True)
            return {
                "error": str(e),
                "current_phase": "error",
            }
class HumanReviewNode(BaseNode):
    """
    Human-review node.

    The graph is compiled with interrupt_before on this node, so execution
    pauses here; human feedback is injected externally via state updates.
    The reviewer can confirm findings, flag false positives, or request
    another analysis pass.
    """

    async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """Announce the pause and surface the findings awaiting review."""
        pending = state.get("verified_findings", [])
        await self.emit_event(
            "human_review",
            f"⏸️ 等待人工审核 ({len(pending)} 个待确认)"
        )
        # No state mutation here; the human decision arrives via update_state.
        return {
            "current_phase": "human_review",
            "messages": [{
                "role": "system",
                "content": "等待人工审核",
                "findings_for_review": pending,
            }],
        }

View File

@ -0,0 +1,621 @@
"""
DeepAudit LangGraph Runner
基于 LangGraph Agent 审计执行器
"""
import asyncio
import logging
import os
import uuid
from datetime import datetime, timezone
from typing import Dict, List, Optional, Any, AsyncGenerator
from sqlalchemy.ext.asyncio import AsyncSession
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from app.models.agent_task import (
AgentTask, AgentEvent, AgentFinding,
AgentTaskStatus, AgentTaskPhase, AgentEventType,
VulnerabilitySeverity, VulnerabilityType, FindingStatus,
)
from app.services.agent.event_manager import EventManager, AgentEventEmitter
from app.services.agent.tools import (
RAGQueryTool, SecurityCodeSearchTool, FunctionContextTool,
PatternMatchTool, CodeAnalysisTool, DataFlowAnalysisTool, VulnerabilityValidationTool,
FileReadTool, FileSearchTool, ListFilesTool,
SandboxTool, SandboxHttpTool, VulnerabilityVerifyTool, SandboxManager,
SemgrepTool, BanditTool, GitleaksTool, NpmAuditTool, SafetyTool,
TruffleHogTool, OSVScannerTool,
)
from app.services.rag import CodeIndexer, CodeRetriever, EmbeddingService
from app.core.config import settings
from .audit_graph import AuditState, create_audit_graph
from .nodes import ReconNode, AnalysisNode, VerificationNode, ReportNode
logger = logging.getLogger(__name__)
class LLMService:
    """Thin wrapper over litellm for chat-completion calls."""

    def __init__(self, model: Optional[str] = None, api_key: Optional[str] = None):
        # Fall back to the globally configured model / key when not given.
        self.model = model or settings.DEFAULT_LLM_MODEL
        self.api_key = api_key or settings.LLM_API_KEY

    async def chat_completion_raw(
        self,
        messages: List[Dict[str, str]],
        temperature: float = 0.1,
        max_tokens: int = 4096,
    ) -> Dict[str, Any]:
        """Call the LLM and return {"content": str, "usage": dict}.

        Raises:
            Whatever litellm raises; errors are logged before re-raising.
        """
        try:
            import litellm
            response = await litellm.acompletion(
                model=self.model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                api_key=self.api_key,
            )
            usage = response.usage
            usage_dict = {
                "prompt_tokens": usage.prompt_tokens,
                "completion_tokens": usage.completion_tokens,
                "total_tokens": usage.total_tokens,
            } if usage else {}
            return {
                "content": response.choices[0].message.content,
                "usage": usage_dict,
            }
        except Exception as e:
            logger.error(f"LLM call failed: {e}")
            raise
class AgentRunner:
"""
DeepAudit LangGraph Agent Runner
基于 LangGraph 状态图的审计执行器
工作流:
START Recon Analysis Verification Report END
"""
def __init__(
    self,
    db: AsyncSession,
    task: AgentTask,
    project_root: str,
):
    """Bind the runner to a DB session, a task row, and the checked-out code.

    Args:
        db: Active async DB session used by the runner for persistence.
        task: The AgentTask row this run executes.
        project_root: Filesystem root of the project to audit.
    """
    self.db = db
    self.task = task
    self.project_root = project_root
    # Event management.
    # FIX: EventManager takes an optional async *session factory* (not a live
    # session), and AgentEventEmitter takes (task_id, event_manager); the
    # previous calls `EventManager(db, task.id)` and
    # `AgentEventEmitter(self.event_manager)` raised TypeError on construction.
    # NOTE(review): without a session factory the manager streams events but
    # does not persist them — TODO wire a real session factory here.
    self.event_manager = EventManager()
    self.event_emitter = AgentEventEmitter(task.id, self.event_manager)
    # LLM service (model/key resolved from settings by default).
    self.llm_service = LLMService()
    # Tool registry, populated by _initialize_tools().
    self.tools: Dict[str, Any] = {}
    # RAG components (stay None if initialisation fails).
    self.retriever: Optional[CodeRetriever] = None
    self.indexer: Optional[CodeIndexer] = None
    # Sandbox manager (optional; the container backend may be unavailable).
    self.sandbox_manager: Optional[SandboxManager] = None
    # LangGraph compiled graph plus in-memory checkpointer.
    self.graph: Optional[StateGraph] = None
    self.checkpointer = MemorySaver()
    # Cooperative-cancellation flag checked between graph steps.
    self._cancelled = False
async def initialize(self):
    """Initialize the runner: RAG system, tool registry, then the graph."""
    await self.event_emitter.emit_info("🚀 正在初始化 DeepAudit LangGraph Agent...")
    # 1. RAG system (embedding service, indexer, retriever).
    await self._initialize_rag()
    # 2. Tool registry (file, RAG, analysis, scanner, sandbox tools).
    await self._initialize_tools()
    # 3. Compile the LangGraph audit workflow.
    await self._build_graph()
    await self.event_emitter.emit_info("✅ LangGraph 系统初始化完成")
async def _initialize_rag(self):
    """Set up the embedding service plus indexer/retriever for this project.

    Failure is non-fatal: the retriever stays None and RAG tools are skipped.
    """
    await self.event_emitter.emit_info("📚 初始化 RAG 代码检索系统...")
    try:
        embeddings = EmbeddingService(
            provider=settings.EMBEDDING_PROVIDER,
            model=settings.EMBEDDING_MODEL,
            api_key=settings.LLM_API_KEY,
            base_url=settings.LLM_BASE_URL,
        )
        # Indexer and retriever share one per-project vector collection.
        collection = f"project_{self.task.project_id}"
        self.indexer = CodeIndexer(
            embedding_service=embeddings,
            vector_db_path=settings.VECTOR_DB_PATH,
            collection_name=collection,
        )
        self.retriever = CodeRetriever(
            embedding_service=embeddings,
            vector_db_path=settings.VECTOR_DB_PATH,
            collection_name=collection,
        )
    except Exception as e:
        logger.warning(f"RAG initialization failed: {e}")
        await self.event_emitter.emit_warning(f"RAG 系统初始化失败: {e}")
async def _initialize_tools(self):
    """Register file, RAG, analysis, scanner and sandbox tools."""
    await self.event_emitter.emit_info("🔧 初始化 Agent 工具集...")
    root = self.project_root
    # File-system tools.
    self.tools.update({
        "read_file": FileReadTool(root),
        "search_code": FileSearchTool(root),
        "list_files": ListFilesTool(root),
    })
    # RAG tools, only when the retriever came up.
    if self.retriever:
        self.tools.update({
            "rag_query": RAGQueryTool(self.retriever),
            "security_search": SecurityCodeSearchTool(self.retriever),
            "function_context": FunctionContextTool(self.retriever),
        })
    # Pattern matching plus LLM-backed analysis tools.
    self.tools["pattern_match"] = PatternMatchTool(root)
    self.tools.update({
        "code_analysis": CodeAnalysisTool(self.llm_service),
        "dataflow_analysis": DataFlowAnalysisTool(self.llm_service),
        "vulnerability_validation": VulnerabilityValidationTool(self.llm_service),
    })
    # External security scanners.
    self.tools.update({
        "semgrep_scan": SemgrepTool(root),
        "bandit_scan": BanditTool(root),
        "gitleaks_scan": GitleaksTool(root),
        "trufflehog_scan": TruffleHogTool(root),
        "npm_audit": NpmAuditTool(root),
        "safety_scan": SafetyTool(root),
        "osv_scan": OSVScannerTool(root),
    })
    # Sandbox tools require a working container backend; skip on failure.
    try:
        self.sandbox_manager = SandboxManager(
            image=settings.SANDBOX_IMAGE,
            memory_limit=settings.SANDBOX_MEMORY_LIMIT,
            cpu_limit=settings.SANDBOX_CPU_LIMIT,
        )
        self.tools.update({
            "sandbox_exec": SandboxTool(self.sandbox_manager),
            "sandbox_http": SandboxHttpTool(self.sandbox_manager),
            "verify_vulnerability": VulnerabilityVerifyTool(self.sandbox_manager),
        })
    except Exception as e:
        logger.warning(f"Sandbox initialization failed: {e}")
    await self.event_emitter.emit_info(f"✅ 已加载 {len(self.tools)} 个工具")
async def _build_graph(self):
    """Instantiate the three agents and compile the LangGraph workflow."""
    await self.event_emitter.emit_info("📊 构建 LangGraph 审计工作流...")
    # Imported lazily to avoid a circular dependency at module load.
    from app.services.agent.agents import ReconAgent, AnalysisAgent, VerificationAgent
    # All agents share the same LLM service, tool registry, and emitter.
    shared = dict(
        llm_service=self.llm_service,
        tools=self.tools,
        event_emitter=self.event_emitter,
    )
    recon_node = ReconNode(ReconAgent(**shared), self.event_emitter)
    analysis_node = AnalysisNode(AnalysisAgent(**shared), self.event_emitter)
    verification_node = VerificationNode(VerificationAgent(**shared), self.event_emitter)
    # The report node aggregates state directly and needs no agent.
    report_node = ReportNode(None, self.event_emitter)
    self.graph = create_audit_graph(
        recon_node=recon_node,
        analysis_node=analysis_node,
        verification_node=verification_node,
        report_node=report_node,
        checkpointer=self.checkpointer,
    )
    await self.event_emitter.emit_info("✅ LangGraph 工作流构建完成")
async def run(self) -> Dict[str, Any]:
    """
    Execute the LangGraph audit end-to-end.

    Pipeline: initialize -> index code -> collect project info ->
    stream the LangGraph workflow -> persist findings -> update the
    task summary and status.

    Returns:
        Dict with "success"; on success also "data" (findings,
        verified_findings, summary, security_score) and "duration_ms",
        otherwise "error".
    """
    import time
    start_time = time.time()
    try:
        # Set up tools, sandbox and the graph.
        await self.initialize()
        # Mark the task as running in the DB.
        await self._update_task_status(AgentTaskStatus.RUNNING)
        # 1. Index the code base for RAG retrieval.
        await self._index_code()
        if self._cancelled:
            return {"success": False, "error": "任务已取消"}
        # 2. Collect basic project metadata.
        project_info = await self._collect_project_info()
        # 3. Build the initial graph state.
        initial_state: AuditState = {
            "project_root": self.project_root,
            "project_info": project_info,
            "config": self.task.config or {},
            "task_id": self.task.id,
            "tech_stack": {},
            "entry_points": [],
            "high_risk_areas": [],
            "dependencies": {},
            "findings": [],
            "verified_findings": [],
            "false_positives": [],
            "current_phase": "start",
            "iteration": 0,
            "max_iterations": (self.task.config or {}).get("max_iterations", 3),
            "should_continue_analysis": False,
            "messages": [],
            "events": [],
            "summary": None,
            "security_score": None,
            "error": None,
        }
        # 4. Run the LangGraph workflow.
        await self.event_emitter.emit_phase_start("langgraph", "🔄 启动 LangGraph 工作流")
        run_config = {
            "configurable": {
                # Thread id ties graph checkpoints to this task.
                "thread_id": self.task.id,
            }
        }
        final_state = None
        # Stream execution and relay events as they arrive.
        async for event in self.graph.astream(initial_state, config=run_config):
            if self._cancelled:
                break
            # Handle each node's output.
            for node_name, node_output in event.items():
                await self._handle_node_output(node_name, node_output)
                # Map graph nodes onto task phases for progress display.
                phase_map = {
                    "recon": AgentTaskPhase.RECONNAISSANCE,
                    "analysis": AgentTaskPhase.ANALYSIS,
                    "verification": AgentTaskPhase.VERIFICATION,
                    "report": AgentTaskPhase.REPORTING,
                }
                if node_name in phase_map:
                    await self._update_task_phase(phase_map[node_name])
                final_state = node_output
        # 5. Fall back to the checkpointed state if streaming yielded none.
        if not final_state:
            graph_state = self.graph.get_state(run_config)
            final_state = graph_state.values if graph_state else {}
        # 6. Persist findings to the database.
        findings = final_state.get("findings", [])
        await self._save_findings(findings)
        # 7. Update the task summary counters/score.
        summary = final_state.get("summary", {})
        security_score = final_state.get("security_score", 100)
        await self._update_task_summary(
            total_findings=len(findings),
            verified_count=len(final_state.get("verified_findings", [])),
            security_score=security_score,
        )
        # 8. Mark completion and emit the terminal event.
        duration_ms = int((time.time() - start_time) * 1000)
        await self._update_task_status(AgentTaskStatus.COMPLETED)
        await self.event_emitter.emit_task_complete(
            findings_count=len(findings),
            duration_ms=duration_ms,
        )
        return {
            "success": True,
            "data": {
                "findings": findings,
                "verified_findings": final_state.get("verified_findings", []),
                "summary": summary,
                "security_score": security_score,
            },
            "duration_ms": duration_ms,
        }
    except asyncio.CancelledError:
        await self._update_task_status(AgentTaskStatus.CANCELLED)
        return {"success": False, "error": "任务已取消"}
    except Exception as e:
        logger.error(f"LangGraph run failed: {e}", exc_info=True)
        await self._update_task_status(AgentTaskStatus.FAILED, str(e))
        await self.event_emitter.emit_error(str(e))
        return {"success": False, "error": str(e)}
    finally:
        # Always release sandbox/event-stream resources.
        await self._cleanup()
async def _handle_node_output(self, node_name: str, output: Dict[str, Any]):
"""处理节点输出"""
# 发射节点事件
events = output.get("events", [])
for evt in events:
await self.event_emitter.emit_info(
f"[{node_name}] {evt.get('type', 'event')}: {evt.get('data', {})}"
)
# 处理新发现
if node_name == "analysis":
new_findings = output.get("findings", [])
if new_findings:
for finding in new_findings[:5]: # 限制事件数量
await self.event_emitter.emit_finding(
title=finding.get("title", "Unknown"),
severity=finding.get("severity", "medium"),
file_path=finding.get("file_path"),
)
# 处理验证结果
if node_name == "verification":
verified = output.get("verified_findings", [])
for v in verified[:5]:
await self.event_emitter.emit_info(
f"✅ 已验证: {v.get('title', 'Unknown')}"
)
# 处理错误
if output.get("error"):
await self.event_emitter.emit_error(output["error"])
async def _index_code(self):
    """Index the project code for RAG retrieval, emitting progress.

    No-op (with a warning) when RAG was not initialized. Indexing
    failures are logged and reported but are non-fatal to the audit.
    """
    if not self.indexer:
        await self.event_emitter.emit_warning("RAG 未初始化,跳过代码索引")
        return
    await self._update_task_phase(AgentTaskPhase.INDEXING)
    await self.event_emitter.emit_phase_start("indexing", "📝 开始代码索引")
    try:
        # Stream indexing progress; abort quietly on cancellation.
        async for progress in self.indexer.index_directory(self.project_root):
            if self._cancelled:
                return
            # max(total, 1) guards against division by zero.
            await self.event_emitter.emit_progress(
                progress.processed / max(progress.total, 1) * 100,
                f"正在索引: {progress.current_file or 'N/A'}"
            )
        await self.event_emitter.emit_phase_complete("indexing", "✅ 代码索引完成")
    except Exception as e:
        logger.warning(f"Code indexing failed: {e}")
        await self.event_emitter.emit_warning(f"代码索引失败: {e}")
async def _collect_project_info(self) -> Dict[str, Any]:
"""收集项目信息"""
info = {
"name": self.task.project.name if self.task.project else "unknown",
"root": self.project_root,
"languages": [],
"file_count": 0,
}
try:
exclude_dirs = {
"node_modules", "__pycache__", ".git", "venv", ".venv",
"build", "dist", "target", ".idea", ".vscode",
}
for root, dirs, files in os.walk(self.project_root):
dirs[:] = [d for d in dirs if d not in exclude_dirs]
info["file_count"] += len(files)
lang_map = {
".py": "Python", ".js": "JavaScript", ".ts": "TypeScript",
".java": "Java", ".go": "Go", ".php": "PHP",
".rb": "Ruby", ".rs": "Rust", ".c": "C", ".cpp": "C++",
}
for f in files:
ext = os.path.splitext(f)[1].lower()
if ext in lang_map and lang_map[ext] not in info["languages"]:
info["languages"].append(lang_map[ext])
except Exception as e:
logger.warning(f"Failed to collect project info: {e}")
return info
async def _save_findings(self, findings: List[Dict]):
    """Persist raw finding dicts as AgentFinding rows.

    Unknown severity/type strings fall back to MEDIUM/OTHER. A failure
    on one row is logged and skipped so a single bad finding does not
    abort the batch; the commit itself is rolled back on failure.
    """
    # String -> enum lookup tables for the model columns.
    severity_map = {
        "critical": VulnerabilitySeverity.CRITICAL,
        "high": VulnerabilitySeverity.HIGH,
        "medium": VulnerabilitySeverity.MEDIUM,
        "low": VulnerabilitySeverity.LOW,
        "info": VulnerabilitySeverity.INFO,
    }
    type_map = {
        "sql_injection": VulnerabilityType.SQL_INJECTION,
        "xss": VulnerabilityType.XSS,
        "command_injection": VulnerabilityType.COMMAND_INJECTION,
        "path_traversal": VulnerabilityType.PATH_TRAVERSAL,
        "ssrf": VulnerabilityType.SSRF,
        "hardcoded_secret": VulnerabilityType.HARDCODED_SECRET,
        "deserialization": VulnerabilityType.INSECURE_DESERIALIZATION,
        "weak_crypto": VulnerabilityType.WEAK_CRYPTO,
    }
    for finding in findings:
        try:
            db_finding = AgentFinding(
                id=str(uuid.uuid4()),
                task_id=self.task.id,
                vulnerability_type=type_map.get(
                    finding.get("vulnerability_type", "other"),
                    VulnerabilityType.OTHER
                ),
                severity=severity_map.get(
                    finding.get("severity", "medium"),
                    VulnerabilitySeverity.MEDIUM
                ),
                title=finding.get("title", "Unknown"),
                description=finding.get("description", ""),
                file_path=finding.get("file_path"),
                line_start=finding.get("line_start"),
                line_end=finding.get("line_end"),
                code_snippet=finding.get("code_snippet"),
                source=finding.get("source"),
                sink=finding.get("sink"),
                # Producers use either key for the fix advice.
                suggestion=finding.get("suggestion") or finding.get("recommendation"),
                is_verified=finding.get("is_verified", False),
                confidence=finding.get("confidence", 0.5),
                poc=finding.get("poc"),
                status=FindingStatus.VERIFIED if finding.get("is_verified") else FindingStatus.OPEN,
            )
            self.db.add(db_finding)
        except Exception as e:
            logger.warning(f"Failed to save finding: {e}")
    try:
        await self.db.commit()
    except Exception as e:
        logger.error(f"Failed to commit findings: {e}")
        await self.db.rollback()
async def _update_task_status(
    self,
    status: AgentTaskStatus,
    error: Optional[str] = None
):
    """Persist a task status transition, stamping lifecycle timestamps."""
    self.task.status = status
    now = datetime.now(timezone.utc)
    if status == AgentTaskStatus.RUNNING:
        self.task.started_at = now
    elif status in (
        AgentTaskStatus.COMPLETED,
        AgentTaskStatus.FAILED,
        AgentTaskStatus.CANCELLED,
    ):
        self.task.finished_at = now
    if error:
        self.task.error_message = error
    try:
        await self.db.commit()
    except Exception as exc:
        logger.error(f"Failed to update task status: {exc}")
async def _update_task_phase(self, phase: AgentTaskPhase):
    """Persist the task's current pipeline phase (best effort)."""
    self.task.current_phase = phase
    try:
        await self.db.commit()
    except Exception as exc:
        logger.error(f"Failed to update task phase: {exc}")
async def _update_task_summary(
self,
total_findings: int,
verified_count: int,
security_score: int,
):
"""更新任务摘要"""
self.task.total_findings = total_findings
self.task.verified_findings = verified_count
self.task.security_score = security_score
try:
await self.db.commit()
except Exception as e:
logger.error(f"Failed to update task summary: {e}")
async def _cleanup(self):
"""清理资源"""
try:
if self.sandbox_manager:
await self.sandbox_manager.cleanup()
await self.event_manager.close()
except Exception as e:
logger.warning(f"Cleanup error: {e}")
def cancel(self):
    """Request cooperative cancellation; the run loop checks this flag."""
    self._cancelled = True
# Convenience entry point
async def run_agent_task(
    db: AsyncSession,
    task: AgentTask,
    project_root: str,
) -> Dict[str, Any]:
    """
    Run an agent audit task to completion.

    Args:
        db: Async database session.
        task: The agent task row to execute.
        project_root: Root directory of the project under audit.

    Returns:
        The audit result dict produced by AgentRunner.run().
    """
    return await AgentRunner(db, task, project_root).run()

View File

@ -0,0 +1,20 @@
"""
Agent 提示词模块
"""
from .system_prompts import (
ORCHESTRATOR_SYSTEM_PROMPT,
ANALYSIS_SYSTEM_PROMPT,
VERIFICATION_SYSTEM_PROMPT,
PLANNING_PROMPT,
REPORTING_PROMPT,
)
__all__ = [
"ORCHESTRATOR_SYSTEM_PROMPT",
"ANALYSIS_SYSTEM_PROMPT",
"VERIFICATION_SYSTEM_PROMPT",
"PLANNING_PROMPT",
"REPORTING_PROMPT",
]

View File

@ -0,0 +1,170 @@
"""
Agent 系统提示词
"""
# 编排 Agent 系统提示词
ORCHESTRATOR_SYSTEM_PROMPT = """你是一个专业的代码安全审计 Agent负责自主分析代码并发现安全漏洞。
## 你的职责
1. 分析项目代码制定审计计划
2. 使用工具深入分析代码
3. 发现并验证安全漏洞
4. 生成详细的漏洞报告
## 审计流程
1. **规划阶段**: 分析项目结构识别高风险区域制定审计计划
2. **索引阶段**: 等待代码索引完成
3. **分析阶段**: 使用工具进行深度代码分析
4. **验证阶段**: 在沙箱中验证发现的漏洞
5. **报告阶段**: 整理发现生成报告
## 重点关注的漏洞类型
- SQL 注入包括 ORM 注入
- XSS 跨站脚本反射型存储型DOM型
- 命令注入和代码注入
- 路径遍历和任意文件访问
- SSRF 服务端请求伪造
- XXE XML 外部实体注入
- 不安全的反序列化
- 认证和授权绕过
- 敏感信息泄露硬编码密钥日志泄露
- 业务逻辑漏洞
- IDOR 不安全的直接对象引用
## 分析方法
1. **快速扫描**: 首先使用 pattern_match 快速发现可疑代码
2. **语义搜索**: 使用 rag_query 查找相关上下文
3. **深度分析**: 对可疑代码使用 code_analysis 深入分析
4. **数据流追踪**: 追踪用户输入到危险函数的路径
5. **漏洞验证**: 在沙箱中验证发现的漏洞
## 工作原则
- 系统性: 不遗漏任何可能的攻击面
- 精准性: 减少误报每个发现都要有充分证据
- 深入性: 不只看表面要理解代码逻辑
- 可操作性: 提供具体的修复建议
## 输出要求
发现漏洞时提供:
- 漏洞类型和严重程度
- 具体位置文件行号
- 漏洞描述和成因
- 利用方式和影响
- 修复建议和示例代码
请开始审计工作使用可用的工具进行分析"""
# 分析 Agent 系统提示词
ANALYSIS_SYSTEM_PROMPT = """你是一个专注于代码漏洞分析的安全专家。
## 你的任务
深入分析代码发现安全漏洞你需要:
1. 识别危险的代码模式
2. 追踪数据流从用户输入到危险函数
3. 判断漏洞是否可利用
4. 评估漏洞的严重程度
## 可用工具
- rag_query: 语义搜索相关代码
- pattern_match: 快速模式匹配
- code_analysis: LLM 深度分析
- read_file: 读取文件内容
- search_code: 关键字搜索
- dataflow_analysis: 数据流分析
- vulnerability_validation: 漏洞验证
## 分析策略
1. 先全局后局部先了解整体架构再深入细节
2. 先快后深先快速扫描再深入可疑点
3. 追踪数据流用户输入 处理逻辑 危险函数
4. 验证每个发现确保不是误报
## 严重程度评估标准
- **Critical**: 可直接导致系统被控制或大规模数据泄露
- **High**: 可导致敏感数据泄露或重要功能被绕过
- **Medium**: 可导致部分数据泄露或需要特定条件利用
- **Low**: 影响有限或利用条件苛刻
请开始分析专注于发现真实的安全漏洞"""
# 验证 Agent 系统提示词
VERIFICATION_SYSTEM_PROMPT = """你是一个专注于漏洞验证的安全专家。
## 你的任务
验证发现的漏洞是否真实存在判断是否为误报
## 验证方法
1. **代码审查**: 仔细分析漏洞代码和上下文
2. **构造 Payload**: 设计能触发漏洞的输入
3. **沙箱测试**: 在隔离环境中测试漏洞
4. **分析结果**: 判断漏洞是否可利用
## 可用工具
- sandbox_exec: 在沙箱中执行命令
- sandbox_http: 发送 HTTP 请求
- verify_vulnerability: 自动验证漏洞
- vulnerability_validation: 深度验证分析
## 验证原则
- 安全第一所有测试在沙箱中进行
- 证据充分验证结果要有明确证据
- 谨慎判断不确定时标记为需要人工审核
## 输出要求
- 验证结果确认/可能/误报
- 验证方法使用的测试方法
- 证据支持判断的具体证据
- PoC如果确认可复现的测试代码
请开始验证工作"""
# 规划提示词
PLANNING_PROMPT = """基于以下项目信息,制定安全审计计划。
## 项目信息
- 名称: {project_name}
- 语言: {languages}
- 文件数量: {file_count}
- 目录结构: {directory_structure}
## 请输出审计计划
包含以下内容JSON格式:
```json
{{
"high_risk_areas": ["高风险目录/文件列表"],
"focus_vulnerabilities": ["重点关注的漏洞类型"],
"audit_order": ["审计顺序"],
"estimated_steps": "预计步骤数",
"special_attention": ["特别注意事项"]
}}
```
## 高风险区域识别原则
1. 用户认证和授权相关代码
2. 数据库操作和 ORM 使用
3. 文件上传和下载功能
4. API 接口和输入处理
5. 第三方服务调用
6. 配置文件和环境变量
7. 加密和密钥管理"""
# 报告生成提示词
REPORTING_PROMPT = """基于审计发现,生成安全审计报告摘要。
## 审计发现
{findings}
## 统计信息
- 总发现数: {total_findings}
- 已验证: {verified_count}
- 严重程度分布: {severity_distribution}
## 请输出报告摘要
包含以下内容:
1. 整体安全评估
2. 主要风险点
3. 优先修复建议
4. 安全改进建议
请用简洁专业的语言描述"""

View File

@ -0,0 +1,61 @@
"""
Agent 工具集
提供 LangChain Agent 使用的各种工具
包括内置工具和外部安全工具
"""
from .base import AgentTool, ToolResult
from .rag_tool import RAGQueryTool, SecurityCodeSearchTool, FunctionContextTool
from .pattern_tool import PatternMatchTool
from .code_analysis_tool import CodeAnalysisTool, DataFlowAnalysisTool, VulnerabilityValidationTool
from .file_tool import FileReadTool, FileSearchTool, ListFilesTool
from .sandbox_tool import SandboxTool, SandboxHttpTool, VulnerabilityVerifyTool, SandboxManager
# 外部安全工具
from .external_tools import (
SemgrepTool,
BanditTool,
GitleaksTool,
NpmAuditTool,
SafetyTool,
TruffleHogTool,
OSVScannerTool,
)
__all__ = [
# 基础
"AgentTool",
"ToolResult",
# RAG 工具
"RAGQueryTool",
"SecurityCodeSearchTool",
"FunctionContextTool",
# 代码分析
"PatternMatchTool",
"CodeAnalysisTool",
"DataFlowAnalysisTool",
"VulnerabilityValidationTool",
# 文件操作
"FileReadTool",
"FileSearchTool",
"ListFilesTool",
# 沙箱
"SandboxTool",
"SandboxHttpTool",
"VulnerabilityVerifyTool",
"SandboxManager",
# 外部安全工具
"SemgrepTool",
"BanditTool",
"GitleaksTool",
"NpmAuditTool",
"SafetyTool",
"TruffleHogTool",
"OSVScannerTool",
]

View File

@ -0,0 +1,156 @@
"""
Agent 工具基类
"""
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Type
from dataclasses import dataclass, field
from pydantic import BaseModel
import logging
import time
logger = logging.getLogger(__name__)
@dataclass
class ToolResult:
    """Outcome of a single tool invocation."""
    success: bool                                   # whether the tool ran successfully
    data: Any = None                                # payload on success
    error: Optional[str] = None                     # error message on failure
    duration_ms: int = 0                            # wall-clock duration, set by execute()
    metadata: Dict[str, Any] = field(default_factory=dict)  # extra structured info

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the result to a plain dict."""
        return {
            "success": self.success,
            "data": self.data,
            "error": self.error,
            "duration_ms": self.duration_ms,
            "metadata": self.metadata,
        }

    def to_string(self, max_length: int = 5000) -> str:
        """Render the result as text suitable for LLM consumption.

        Dicts/lists are pretty-printed as JSON; anything else is
        stringified. Output beyond max_length is truncated with a
        marker noting the original size.
        """
        if not self.success:
            return f"Error: {self.error}"
        payload = self.data
        if isinstance(payload, str):
            text = payload
        elif isinstance(payload, (dict, list)):
            import json
            text = json.dumps(payload, ensure_ascii=False, indent=2)
        else:
            text = str(payload)
        if len(text) > max_length:
            text = text[:max_length] + f"\n... (truncated, total {len(text)} chars)"
        return text
class AgentTool(ABC):
    """
    Base class for agent tools.

    Subclasses implement `name`, `description` and `_execute`;
    `execute` adds timing, logging and error capture around
    `_execute` and never raises.
    """
    def __init__(self):
        # Per-instance usage counters, exposed via the `stats` property.
        self._call_count = 0
        self._total_duration_ms = 0
    @property
    @abstractmethod
    def name(self) -> str:
        """Tool name (the identifier the agent invokes it by)."""
        pass
    @property
    @abstractmethod
    def description(self) -> str:
        """Tool description (lets the agent decide when to use it)."""
        pass
    @property
    def args_schema(self) -> Optional[Type[BaseModel]]:
        """Optional Pydantic model describing the tool's arguments."""
        return None
    @abstractmethod
    async def _execute(self, **kwargs) -> ToolResult:
        """Run the tool (implemented by subclasses)."""
        pass
    async def execute(self, **kwargs) -> ToolResult:
        """Run the tool with timing, logging and exception capture.

        Never raises: any exception from `_execute` is converted into
        a failed ToolResult carrying the error message.
        """
        start_time = time.time()
        try:
            logger.debug(f"Tool '{self.name}' executing with args: {kwargs}")
            result = await self._execute(**kwargs)
        except Exception as e:
            logger.error(f"Tool '{self.name}' error: {e}", exc_info=True)
            result = ToolResult(
                success=False,
                error=str(e),
            )
        # Record duration on the result and in the usage counters.
        duration_ms = int((time.time() - start_time) * 1000)
        result.duration_ms = duration_ms
        self._call_count += 1
        self._total_duration_ms += duration_ms
        logger.debug(f"Tool '{self.name}' completed in {duration_ms}ms, success={result.success}")
        return result
    def get_langchain_tool(self):
        """Wrap this tool as a LangChain Tool / StructuredTool.

        Exposes both a sync and an async entry point, since LangChain
        may call either depending on the executor.
        """
        from langchain.tools import Tool, StructuredTool
        import asyncio
        def sync_wrapper(**kwargs):
            """Synchronous bridge into the async execute().

            NOTE(review): asyncio.get_event_loop() is deprecated when no
            loop is running (Python 3.10+) and can raise in a loop-less
            worker thread — consider get_running_loop()/asyncio.run.
            """
            loop = asyncio.get_event_loop()
            if loop.is_running():
                # Already inside a loop: run the coroutine on a fresh
                # event loop in a worker thread to avoid re-entrancy.
                import concurrent.futures
                with concurrent.futures.ThreadPoolExecutor() as executor:
                    future = executor.submit(asyncio.run, self.execute(**kwargs))
                    result = future.result()
            else:
                result = asyncio.run(self.execute(**kwargs))
            return result.to_string()
        async def async_wrapper(**kwargs):
            """Async bridge; used when the caller is already async."""
            result = await self.execute(**kwargs)
            return result.to_string()
        if self.args_schema:
            return StructuredTool(
                name=self.name,
                description=self.description,
                func=sync_wrapper,
                coroutine=async_wrapper,
                args_schema=self.args_schema,
            )
        else:
            # No schema: expose a single positional "query" argument.
            return Tool(
                name=self.name,
                description=self.description,
                func=lambda x: sync_wrapper(query=x),
                coroutine=lambda x: async_wrapper(query=x),
            )
    @property
    def stats(self) -> Dict[str, Any]:
        """Usage statistics accumulated across execute() calls."""
        return {
            "name": self.name,
            "call_count": self._call_count,
            "total_duration_ms": self._total_duration_ms,
            "avg_duration_ms": self._total_duration_ms // max(1, self._call_count),
        }

View File

@ -0,0 +1,427 @@
"""
代码分析工具
使用 LLM 深度分析代码安全问题
"""
import json
from typing import Optional, List, Dict, Any
from pydantic import BaseModel, Field
from .base import AgentTool, ToolResult
class CodeAnalysisInput(BaseModel):
    """Input schema for the code-analysis tool."""
    code: str = Field(description="要分析的代码内容")
    file_path: str = Field(default="unknown", description="文件路径")
    language: str = Field(default="python", description="编程语言")
    focus: Optional[str] = Field(
        default=None,
        description="重点关注的漏洞类型,如 sql_injection, xss, command_injection"
    )
    context: Optional[str] = Field(
        default=None,
        description="额外的上下文信息,如相关的其他代码片段"
    )
class CodeAnalysisTool(AgentTool):
    """
    Code-analysis tool.

    Performs deep LLM-based security analysis of a code snippet and
    formats the reported issues into a text report for the agent.
    """
    def __init__(self, llm_service):
        """
        Args:
            llm_service: LLM service instance used for the analysis.
        """
        super().__init__()
        self.llm_service = llm_service
    @property
    def name(self) -> str:
        return "code_analysis"
    @property
    def description(self) -> str:
        return """深度分析代码安全问题。
使用 LLM 对代码进行全面的安全审计识别潜在漏洞
使用场景:
- 对疑似有问题的代码进行深入分析
- 分析复杂的业务逻辑漏洞
- 追踪数据流和污点传播
- 生成详细的漏洞报告和修复建议
输入:
- code: 要分析的代码
- file_path: 文件路径
- language: 编程语言
- focus: 可选重点关注的漏洞类型
- context: 可选额外的上下文代码
这个工具会消耗较多的 Token建议在确认有疑似问题后使用"""
    @property
    def args_schema(self):
        return CodeAnalysisInput
    async def _execute(
        self,
        code: str,
        file_path: str = "unknown",
        language: str = "python",
        focus: Optional[str] = None,
        context: Optional[str] = None,
        **kwargs
    ) -> ToolResult:
        """Run the LLM analysis and format its findings.

        NOTE(review): `focus` and `context` are accepted but never
        forwarded to analyze_code — confirm whether the LLM service
        is supposed to receive them.
        """
        try:
            # Delegate the actual analysis to the LLM service.
            analysis = await self.llm_service.analyze_code(code, language)
            issues = analysis.get("issues", [])
            if not issues:
                return ToolResult(
                    success=True,
                    data="代码分析完成,未发现明显的安全问题。\n\n"
                    f"质量评分: {analysis.get('quality_score', 'N/A')}\n"
                    f"文件: {file_path}",
                    metadata={
                        "file_path": file_path,
                        "issues_count": 0,
                        "quality_score": analysis.get("quality_score"),
                    }
                )
            # Format each issue into the text report.
            output_parts = [f"🔍 代码分析结果 - {file_path}\n"]
            output_parts.append(f"发现 {len(issues)} 个问题:\n")
            for i, issue in enumerate(issues):
                severity_icon = {
                    "critical": "🔴",
                    "high": "🟠",
                    "medium": "🟡",
                    "low": "🟢"
                }.get(issue.get("severity", ""), "")
                output_parts.append(f"\n{severity_icon} 问题 {i+1}: {issue.get('title', 'Unknown')}")
                output_parts.append(f" 类型: {issue.get('type', 'unknown')}")
                output_parts.append(f" 严重程度: {issue.get('severity', 'unknown')}")
                output_parts.append(f" 行号: {issue.get('line', 'N/A')}")
                output_parts.append(f" 描述: {issue.get('description', '')}")
                if issue.get("code_snippet"):
                    output_parts.append(f" 代码片段:\n ```\n {issue.get('code_snippet')}\n ```")
                if issue.get("suggestion"):
                    output_parts.append(f" 修复建议: {issue.get('suggestion')}")
                if issue.get("ai_explanation"):
                    output_parts.append(f" AI解释: {issue.get('ai_explanation')}")
            output_parts.append(f"\n质量评分: {analysis.get('quality_score', 'N/A')}/100")
            return ToolResult(
                success=True,
                data="\n".join(output_parts),
                metadata={
                    "file_path": file_path,
                    "issues_count": len(issues),
                    "quality_score": analysis.get("quality_score"),
                    "issues": issues,
                }
            )
        except Exception as e:
            return ToolResult(
                success=False,
                error=f"代码分析失败: {str(e)}",
            )
class DataFlowAnalysisInput(BaseModel):
    """Input schema for the data-flow analysis tool."""
    source_code: str = Field(description="包含数据源的代码")
    sink_code: Optional[str] = Field(default=None, description="包含数据汇的代码(如危险函数)")
    variable_name: str = Field(description="要追踪的变量名")
    file_path: str = Field(default="unknown", description="文件路径")
class DataFlowAnalysisTool(AgentTool):
    """
    Data-flow analysis tool.

    Asks the LLM to trace a variable from its source to potential
    sinks and reports whether the flow looks exploitable.
    """
    def __init__(self, llm_service):
        super().__init__()
        self.llm_service = llm_service
    @property
    def name(self) -> str:
        return "dataflow_analysis"
    @property
    def description(self) -> str:
        return """分析代码中的数据流,追踪变量从源(如用户输入)到汇(如危险函数)的路径。
使用场景:
- 追踪用户输入如何流向危险函数
- 分析变量是否经过净化处理
- 识别污点传播路径
输入:
- source_code: 包含数据源的代码
- sink_code: 包含数据汇的代码可选
- variable_name: 要追踪的变量名
- file_path: 文件路径"""
    @property
    def args_schema(self):
        return DataFlowAnalysisInput
    async def _execute(
        self,
        source_code: str,
        variable_name: str,
        sink_code: Optional[str] = None,
        file_path: str = "unknown",
        **kwargs
    ) -> ToolResult:
        """Run the LLM-backed data-flow analysis and format the result."""
        try:
            # Build the analysis prompt; sink code is appended only
            # when provided.
            analysis_prompt = f"""分析以下代码中变量 '{variable_name}' 的数据流。
源代码:
```
{source_code}
```
"""
            if sink_code:
                analysis_prompt += f"""
汇代码可能的危险函数:
```
{sink_code}
```
"""
            analysis_prompt += f"""
请分析:
1. 变量 '{variable_name}' 的来源是什么用户输入配置数据库等
2. 变量在传递过程中是否经过了净化/验证
3. 变量最终流向了哪些危险函数
4. 是否存在安全风险
请返回 JSON 格式的分析结果包含:
- source_type: 数据源类型
- sanitized: 是否经过净化
- sanitization_methods: 使用的净化方法
- dangerous_sinks: 流向的危险函数列表
- risk_level: 风险等级 (high/medium/low/none)
- explanation: 详细解释
- recommendation: 建议
"""
            # Invoke the LLM via the custom-prompt entry point.
            result = await self.llm_service.analyze_code_with_custom_prompt(
                code=source_code,
                language="text",
                custom_prompt=analysis_prompt,
            )
            # Format the (dict or raw) result for the agent.
            output_parts = [f"📊 数据流分析结果 - 变量: {variable_name}\n"]
            if isinstance(result, dict):
                if result.get("source_type"):
                    output_parts.append(f"数据源: {result.get('source_type')}")
                if result.get("sanitized") is not None:
                    sanitized = "✅ 是" if result.get("sanitized") else "❌ 否"
                    output_parts.append(f"是否净化: {sanitized}")
                if result.get("sanitization_methods"):
                    output_parts.append(f"净化方法: {', '.join(result.get('sanitization_methods', []))}")
                if result.get("dangerous_sinks"):
                    output_parts.append(f"危险函数: {', '.join(result.get('dangerous_sinks', []))}")
                if result.get("risk_level"):
                    risk_icons = {"high": "🔴", "medium": "🟠", "low": "🟡", "none": "🟢"}
                    icon = risk_icons.get(result.get("risk_level", ""), "")
                    output_parts.append(f"风险等级: {icon} {result.get('risk_level', '').upper()}")
                if result.get("explanation"):
                    output_parts.append(f"\n分析: {result.get('explanation')}")
                if result.get("recommendation"):
                    output_parts.append(f"\n建议: {result.get('recommendation')}")
            else:
                output_parts.append(str(result))
            return ToolResult(
                success=True,
                data="\n".join(output_parts),
                metadata={
                    "variable": variable_name,
                    "file_path": file_path,
                    "analysis": result,
                }
            )
        except Exception as e:
            return ToolResult(
                success=False,
                error=f"数据流分析失败: {str(e)}",
            )
class VulnerabilityValidationInput(BaseModel):
    """Input schema for the vulnerability-validation tool."""
    code: str = Field(description="可能存在漏洞的代码")
    vulnerability_type: str = Field(description="漏洞类型")
    file_path: str = Field(default="unknown", description="文件路径")
    line_number: Optional[int] = Field(default=None, description="行号")
    context: Optional[str] = Field(default=None, description="额外上下文")
class VulnerabilityValidationTool(AgentTool):
    """
    Vulnerability validation tool.

    Asks the LLM to judge whether a suspected vulnerability is real
    and formats the verdict, evidence and PoC idea into a report.
    """
    def __init__(self, llm_service):
        super().__init__()
        self.llm_service = llm_service
    @property
    def name(self) -> str:
        return "vulnerability_validation"
    @property
    def description(self) -> str:
        return """验证疑似漏洞是否真实存在。
对发现的潜在漏洞进行深入分析判断是否为真正的安全问题
输入:
- code: 包含疑似漏洞的代码
- vulnerability_type: 漏洞类型 sql_injection, xss
- file_path: 文件路径
- line_number: 可选行号
- context: 可选额外的上下文代码
输出:
- 验证结果确认/可能/误报
- 详细分析
- 利用条件
- PoC 思路如果确认存在漏洞"""
    @property
    def args_schema(self):
        return VulnerabilityValidationInput
    async def _execute(
        self,
        code: str,
        vulnerability_type: str,
        file_path: str = "unknown",
        line_number: Optional[int] = None,
        context: Optional[str] = None,
        **kwargs
    ) -> ToolResult:
        """Run LLM validation of the suspected vulnerability."""
        try:
            # Build the validation prompt; the chr(10) expression
            # splices the optional context block in with newlines.
            validation_prompt = f"""你是一个专业的安全研究员,请验证以下代码中是否真的存在 {vulnerability_type} 漏洞。
代码:
```
{code}
```
{f'额外上下文:' + chr(10) + '```' + chr(10) + context + chr(10) + '```' if context else ''}
请分析:
1. 这段代码是否真的存在 {vulnerability_type} 漏洞
2. 漏洞的利用条件是什么
3. 攻击者如何利用这个漏洞
4. 这是否可能是误报为什么
请返回 JSON 格式:
{{
"is_vulnerable": true/false/null (null表示无法确定),
"confidence": 0.0-1.0,
"verdict": "confirmed/likely/unlikely/false_positive",
"exploitation_conditions": ["条件1", "条件2"],
"attack_vector": "攻击向量描述",
"poc_idea": "PoC思路如果存在漏洞",
"false_positive_reason": "如果是误报,说明原因",
"detailed_analysis": "详细分析"
}}
"""
            result = await self.llm_service.analyze_code_with_custom_prompt(
                code=code,
                language="text",
                custom_prompt=validation_prompt,
            )
            # Format the verdict for the agent.
            output_parts = [f"🔎 漏洞验证结果 - {vulnerability_type}\n"]
            output_parts.append(f"文件: {file_path}")
            if line_number:
                output_parts.append(f"行号: {line_number}")
            output_parts.append("")
            if isinstance(result, dict):
                # Map the verdict string onto a human-readable label.
                verdict_icons = {
                    "confirmed": "🔴 确认存在漏洞",
                    "likely": "🟠 可能存在漏洞",
                    "unlikely": "🟡 可能是误报",
                    "false_positive": "🟢 误报",
                }
                verdict = result.get("verdict", "unknown")
                output_parts.append(f"判定: {verdict_icons.get(verdict, verdict)}")
                if result.get("confidence"):
                    output_parts.append(f"置信度: {result.get('confidence') * 100:.0f}%")
                if result.get("exploitation_conditions"):
                    output_parts.append(f"\n利用条件:")
                    for cond in result.get("exploitation_conditions", []):
                        output_parts.append(f" - {cond}")
                if result.get("attack_vector"):
                    output_parts.append(f"\n攻击向量: {result.get('attack_vector')}")
                if result.get("poc_idea") and verdict in ["confirmed", "likely"]:
                    output_parts.append(f"\nPoC思路: {result.get('poc_idea')}")
                if result.get("false_positive_reason") and verdict in ["unlikely", "false_positive"]:
                    output_parts.append(f"\n误报原因: {result.get('false_positive_reason')}")
                if result.get("detailed_analysis"):
                    output_parts.append(f"\n详细分析:\n{result.get('detailed_analysis')}")
            else:
                output_parts.append(str(result))
            return ToolResult(
                success=True,
                data="\n".join(output_parts),
                metadata={
                    "vulnerability_type": vulnerability_type,
                    "file_path": file_path,
                    "line_number": line_number,
                    "validation": result,
                }
            )
        except Exception as e:
            return ToolResult(
                success=False,
                error=f"漏洞验证失败: {str(e)}",
            )

View File

@ -0,0 +1,948 @@
"""
外部安全工具集成
集成 SemgrepBanditGitleaksTruffleHognpm audit 等专业安全工具
"""
import asyncio
import json
import logging
import os
import tempfile
import shutil
from typing import Optional, List, Dict, Any
from pydantic import BaseModel, Field
from dataclasses import dataclass
from .base import AgentTool, ToolResult
logger = logging.getLogger(__name__)
# ============ Semgrep tool ============
class SemgrepInput(BaseModel):
    """Input schema for the Semgrep scan tool."""
    target_path: str = Field(description="要扫描的目录或文件路径(相对于项目根目录)")
    rules: Optional[str] = Field(
        default="auto",
        description="规则集: auto, p/security-audit, p/owasp-top-ten, p/r2c-security-audit, 或自定义规则文件路径"
    )
    severity: Optional[str] = Field(
        default=None,
        description="过滤严重程度: ERROR, WARNING, INFO"
    )
    max_results: int = Field(default=50, description="最大返回结果数")
class SemgrepTool(AgentTool):
    """
    Semgrep static-analysis tool.

    Semgrep is a fast, lightweight static analyzer with broad language
    support and a rich registry of security rulesets.

    Registry rulesets:
    - p/security-audit: comprehensive security audit
    - p/owasp-top-ten: OWASP Top 10 vulnerabilities
    - p/r2c-security-audit: R2C security audit rules
    - p/python, p/javascript, ...: language-specific rules
    """
    # Rulesets the agent may pass via `rules` (informational list).
    AVAILABLE_RULESETS = [
        "auto",
        "p/security-audit",
        "p/owasp-top-ten",
        "p/r2c-security-audit",
        "p/python",
        "p/javascript",
        "p/typescript",
        "p/java",
        "p/go",
        "p/php",
        "p/ruby",
        "p/secrets",
        "p/sql-injection",
        "p/xss",
        "p/command-injection",
    ]
    def __init__(self, project_root: str):
        super().__init__()
        # All scans are confined to this directory.
        self.project_root = project_root
    @property
    def name(self) -> str:
        return "semgrep_scan"
    @property
    def description(self) -> str:
        return """使用 Semgrep 进行静态安全分析。
Semgrep 是业界领先的静态分析工具支持 30+ 种编程语言
可用规则集:
- auto: 自动选择最佳规则
- p/security-audit: 综合安全审计
- p/owasp-top-ten: OWASP Top 10 漏洞检测
- p/secrets: 密钥泄露检测
- p/sql-injection: SQL 注入检测
- p/xss: XSS 检测
- p/command-injection: 命令注入检测
使用场景:
- 快速全面的代码安全扫描
- 检测常见安全漏洞模式
- 遵循行业安全标准审计"""
    @property
    def args_schema(self):
        return SemgrepInput
    async def _execute(
        self,
        target_path: str = ".",
        rules: str = "auto",
        severity: Optional[str] = None,
        max_results: int = 50,
        **kwargs
    ) -> ToolResult:
        """Run a Semgrep scan confined to the project root.

        Args:
            target_path: File or directory, relative to the project root.
            rules: Ruleset name ("auto", "p/..."), or a rule file path.
            severity: Optional severity filter (ERROR/WARNING/INFO).
            max_results: Cap on findings included in the report.
        """
        # Bail out early when the semgrep binary is missing.
        if not await self._check_semgrep():
            return ToolResult(
                success=False,
                error="Semgrep 未安装。请使用 'pip install semgrep' 安装。",
            )
        # Containment check. A plain startswith() prefix test would
        # accept sibling directories such as "/project-evil" for root
        # "/project", so compare path components via commonpath.
        root = os.path.normpath(self.project_root)
        full_path = os.path.normpath(os.path.join(self.project_root, target_path))
        try:
            inside = os.path.commonpath([root, full_path]) == root
        except ValueError:
            # Different drives / mixed absolute-relative paths.
            inside = False
        if not inside:
            return ToolResult(
                success=False,
                error="安全错误:不允许扫描项目目录外的路径",
            )
        # Build the command. "auto", registry rulesets ("p/...") and
        # custom rule files are all passed straight to --config, so a
        # single branch suffices.
        cmd = ["semgrep", "--json", "--quiet", "--config", rules]
        if severity:
            cmd.extend(["--severity", severity])
        cmd.append(full_path)
        try:
            proc = await asyncio.create_subprocess_exec(
                *cmd,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
                cwd=self.project_root,
            )
            try:
                stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=300)
            except asyncio.TimeoutError:
                # Kill the runaway scanner instead of leaking it.
                proc.kill()
                return ToolResult(success=False, error="Semgrep 扫描超时")
            if proc.returncode not in [0, 1]:  # 1 means findings were found
                return ToolResult(
                    success=False,
                    error=f"Semgrep 执行失败: {stderr.decode()[:500]}",
                )
            # Parse the JSON report.
            try:
                results = json.loads(stdout.decode())
            except json.JSONDecodeError:
                return ToolResult(
                    success=False,
                    error="无法解析 Semgrep 输出",
                )
            findings = results.get("results", [])[:max_results]
            if not findings:
                return ToolResult(
                    success=True,
                    data=f"Semgrep 扫描完成,未发现安全问题 (规则集: {rules})",
                    metadata={"findings_count": 0, "rules": rules}
                )
            # Format the findings into the text report.
            output_parts = [f"🔍 Semgrep 扫描结果 (规则集: {rules})\n"]
            output_parts.append(f"发现 {len(findings)} 个问题:\n")
            severity_icons = {"ERROR": "🔴", "WARNING": "🟠", "INFO": "🟡"}
            for finding in findings:
                sev = finding.get("extra", {}).get("severity", "INFO")
                icon = severity_icons.get(sev, "")
                output_parts.append(f"\n{icon} [{sev}] {finding.get('check_id', 'unknown')}")
                output_parts.append(f" 文件: {finding.get('path', '')}:{finding.get('start', {}).get('line', 0)}")
                output_parts.append(f" 消息: {finding.get('extra', {}).get('message', '')[:200]}")
                # Matched source snippet, if provided.
                lines = finding.get("extra", {}).get("lines", "")
                if lines:
                    output_parts.append(f" 代码: {lines[:150]}")
            return ToolResult(
                success=True,
                data="\n".join(output_parts),
                metadata={
                    "findings_count": len(findings),
                    "rules": rules,
                    "findings": findings[:10],
                }
            )
        except Exception as e:
            return ToolResult(success=False, error=f"Semgrep 执行错误: {str(e)}")
    async def _check_semgrep(self) -> bool:
        """Return True when the `semgrep` binary is runnable."""
        try:
            proc = await asyncio.create_subprocess_exec(
                "semgrep", "--version",
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
            await proc.communicate()
            return proc.returncode == 0
        except Exception:
            # Was a bare `except:`; don't swallow KeyboardInterrupt etc.
            return False
# ============ Bandit tool (Python) ============
class BanditInput(BaseModel):
    """Input schema for the Bandit scan tool."""
    target_path: str = Field(default=".", description="要扫描的 Python 目录或文件")
    severity: str = Field(default="medium", description="最低严重程度: low, medium, high")
    confidence: str = Field(default="medium", description="最低置信度: low, medium, high")
    max_results: int = Field(default=50, description="最大返回结果数")
class BanditTool(AgentTool):
    """
    Bandit security scanner for Python code.

    Detects common Python security issues such as hardcoded
    credentials, SQL injection, command injection, insecure random
    number generation and unsafe deserialization.
    """
    def __init__(self, project_root: str):
        super().__init__()
        # All scans are confined to this directory.
        self.project_root = project_root
    @property
    def name(self) -> str:
        return "bandit_scan"
    @property
    def description(self) -> str:
        return """使用 Bandit 扫描 Python 代码的安全问题。
Bandit Python 专用的安全分析工具 OpenStack 安全团队开发
检测项目:
- B101: assert 使用
- B102: exec 使用
- B103-B108: 文件权限问题
- B301-B312: pickle/yaml 反序列化
- B501-B508: SSL/TLS 问题
- B601-B608: shell/SQL 注入
- B701-B703: Jinja2 模板问题
仅适用于 Python 项目"""
    @property
    def args_schema(self):
        return BanditInput
    async def _execute(
        self,
        target_path: str = ".",
        severity: str = "medium",
        confidence: str = "medium",
        max_results: int = 50,
        **kwargs
    ) -> ToolResult:
        """Run bandit recursively over a path inside the project root.

        Args:
            target_path: Directory/file relative to the project root.
            severity: Minimum severity (low/medium/high).
            confidence: Minimum confidence (low/medium/high).
            max_results: Cap on findings included in the report.
        """
        if not await self._check_bandit():
            return ToolResult(
                success=False,
                error="Bandit 未安装。请使用 'pip install bandit' 安装。",
            )
        # Containment check via commonpath: a bare startswith() prefix
        # test would accept sibling dirs like "/proj-evil" for "/proj".
        root = os.path.normpath(self.project_root)
        full_path = os.path.normpath(os.path.join(self.project_root, target_path))
        try:
            inside = os.path.commonpath([root, full_path]) == root
        except ValueError:
            inside = False
        if not inside:
            return ToolResult(success=False, error="安全错误:路径越界")
        # Bandit expresses thresholds by flag repetition: -l/-ll/-lll is
        # the minimum severity (low/medium/high), -i/-ii/-iii the
        # minimum confidence. The previous code built invalid options
        # such as "-lm"/"-im", which bandit rejects.
        severity_flags = {"low": "-l", "medium": "-ll", "high": "-lll"}
        confidence_flags = {"low": "-i", "medium": "-ii", "high": "-iii"}
        cmd = [
            "bandit", "-r", "-f", "json",
            severity_flags.get(severity, "-ll"),
            confidence_flags.get(confidence, "-ii"),
            full_path
        ]
        try:
            proc = await asyncio.create_subprocess_exec(
                *cmd,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
            try:
                stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=120)
            except asyncio.TimeoutError:
                # Kill the runaway scanner instead of leaking it.
                proc.kill()
                return ToolResult(success=False, error="Bandit 扫描超时")
            # Bandit exits 0 (clean) or 1 (issues found); anything else
            # is a real failure worth surfacing instead of a JSON error.
            if proc.returncode not in (0, 1):
                return ToolResult(
                    success=False,
                    error=f"Bandit 执行错误: {stderr.decode()[:500]}",
                )
            try:
                results = json.loads(stdout.decode())
            except json.JSONDecodeError:
                return ToolResult(success=False, error="无法解析 Bandit 输出")
            findings = results.get("results", [])[:max_results]
            if not findings:
                return ToolResult(
                    success=True,
                    data="Bandit 扫描完成,未发现 Python 安全问题",
                    metadata={"findings_count": 0}
                )
            # Format the findings into the text report.
            output_parts = ["🐍 Bandit Python 安全扫描结果\n"]
            output_parts.append(f"发现 {len(findings)} 个问题:\n")
            severity_icons = {"HIGH": "🔴", "MEDIUM": "🟠", "LOW": "🟡"}
            for finding in findings:
                sev = finding.get("issue_severity", "LOW")
                icon = severity_icons.get(sev, "")
                output_parts.append(f"\n{icon} [{sev}] {finding.get('test_id', '')}: {finding.get('test_name', '')}")
                output_parts.append(f" 文件: {finding.get('filename', '')}:{finding.get('line_number', 0)}")
                output_parts.append(f" 消息: {finding.get('issue_text', '')[:200]}")
                output_parts.append(f" 代码: {finding.get('code', '')[:100]}")
            return ToolResult(
                success=True,
                data="\n".join(output_parts),
                metadata={"findings_count": len(findings), "findings": findings[:10]}
            )
        except Exception as e:
            return ToolResult(success=False, error=f"Bandit 执行错误: {str(e)}")
    async def _check_bandit(self) -> bool:
        """Return True when the `bandit` binary is runnable."""
        try:
            proc = await asyncio.create_subprocess_exec(
                "bandit", "--version",
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
            await proc.communicate()
            return proc.returncode == 0
        except Exception:
            # Was a bare `except:`; don't swallow KeyboardInterrupt etc.
            return False
# ============ Gitleaks 工具 ============
class GitleaksInput(BaseModel):
    """Input schema for a Gitleaks secret-leak scan."""
    # Directory to scan, relative to the project root.
    target_path: str = Field(default=".", description="要扫描的目录")
    # When True, scan the working tree only instead of git history.
    no_git: bool = Field(default=True, description="不使用 git history仅扫描文件")
    # Upper bound on the number of findings included in the report.
    max_results: int = Field(default=50, description="最大返回结果数")
class GitleaksTool(AgentTool):
    """
    Gitleaks secret-leak detection tool.

    Gitleaks is a scanner dedicated to finding hardcoded secrets in code:
    - API keys (AWS, GCP, Azure, GitHub, etc.)
    - Private keys (RSA, SSH, PGP)
    - Database credentials
    - OAuth tokens
    - JWT secrets
    """

    def __init__(self, project_root: str):
        """
        Args:
            project_root: Directory that bounds every scan (traversal guard).
        """
        super().__init__()
        self.project_root = project_root

    @property
    def name(self) -> str:
        return "gitleaks_scan"

    @property
    def description(self) -> str:
        return """使用 Gitleaks 检测代码中的密钥泄露。
Gitleaks 是专业的密钥检测工具支持 150+ 种密钥类型
检测类型:
- AWS Access Keys / Secret Keys
- GCP API Keys / Service Account Keys
- Azure Credentials
- GitHub / GitLab Tokens
- Private Keys (RSA, SSH, PGP)
- Database Connection Strings
- JWT Secrets
- Slack / Discord Tokens
- 等等...
建议在代码审计早期使用此工具"""

    @property
    def args_schema(self):
        return GitleaksInput

    async def _execute(
        self,
        target_path: str = ".",
        no_git: bool = True,
        max_results: int = 50,
        **kwargs
    ) -> ToolResult:
        """Run Gitleaks on a directory inside the project and summarise findings.

        Secret values are partially masked before being included in the report.
        """
        if not await self._check_gitleaks():
            return ToolResult(
                success=False,
                error="Gitleaks 未安装。请从 https://github.com/gitleaks/gitleaks 安装。",
            )
        # Path-traversal guard anchored at a separator boundary. A plain
        # startswith() prefix test would wrongly accept sibling directories
        # (root "/proj" matching "/proj2").
        root = os.path.normpath(self.project_root)
        full_path = os.path.normpath(os.path.join(root, target_path))
        if full_path != root and not full_path.startswith(root + os.sep):
            return ToolResult(success=False, error="安全错误:路径越界")
        cmd = ["gitleaks", "detect", "--source", full_path, "-f", "json"]
        if no_git:
            cmd.append("--no-git")
        try:
            proc = await asyncio.create_subprocess_exec(
                *cmd,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
            stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=120)
            # Gitleaks exits with 1 when secrets are found; anything else is a failure.
            if proc.returncode not in [0, 1]:
                return ToolResult(success=False, error=f"Gitleaks 执行失败: {stderr.decode()[:300]}")
            if not stdout.strip():
                return ToolResult(
                    success=True,
                    data="🔐 Gitleaks 扫描完成,未发现密钥泄露",
                    metadata={"findings_count": 0}
                )
            try:
                findings = json.loads(stdout.decode())
            except json.JSONDecodeError:
                findings = []
            if not findings:
                return ToolResult(
                    success=True,
                    data="🔐 Gitleaks 扫描完成,未发现密钥泄露",
                    metadata={"findings_count": 0}
                )
            findings = findings[:max_results]
            output_parts = ["🔐 Gitleaks 密钥泄露检测结果\n"]
            output_parts.append(f"⚠️ 发现 {len(findings)} 处密钥泄露!\n")
            for i, finding in enumerate(findings):
                output_parts.append(f"\n🔴 [{i+1}] {finding.get('RuleID', 'unknown')}")
                output_parts.append(f" 描述: {finding.get('Description', '')}")
                output_parts.append(f" 文件: {finding.get('File', '')}:{finding.get('StartLine', 0)}")
                # Mask the secret so the report itself never leaks it in full.
                secret = finding.get('Secret', '')
                if len(secret) > 8:
                    masked = secret[:4] + '*' * (len(secret) - 8) + secret[-4:]
                else:
                    masked = '*' * len(secret)
                output_parts.append(f" 密钥: {masked}")
            return ToolResult(
                success=True,
                data="\n".join(output_parts),
                metadata={
                    "findings_count": len(findings),
                    "findings": [
                        {"rule": f.get("RuleID"), "file": f.get("File"), "line": f.get("StartLine")}
                        for f in findings[:10]
                    ]
                }
            )
        except asyncio.TimeoutError:
            return ToolResult(success=False, error="Gitleaks 扫描超时")
        except Exception as e:
            return ToolResult(success=False, error=f"Gitleaks 执行错误: {str(e)}")

    async def _check_gitleaks(self) -> bool:
        """Best-effort probe: return True when the gitleaks binary is runnable."""
        try:
            proc = await asyncio.create_subprocess_exec(
                "gitleaks", "version",
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
            await proc.communicate()
            return proc.returncode == 0
        except Exception:
            # Exception (not a bare except) so task cancellation still propagates.
            return False
# ============ npm audit 工具 ============
class NpmAuditInput(BaseModel):
    """Input schema for an npm audit dependency scan."""
    # Directory containing package.json, relative to the project root.
    target_path: str = Field(default=".", description="包含 package.json 的目录")
    # When True, pass --production so devDependencies are skipped.
    production_only: bool = Field(default=False, description="仅扫描生产依赖")
class NpmAuditTool(AgentTool):
    """
    npm audit dependency-vulnerability scanner.

    Scans a Node.js project's dependencies against the official npm advisory
    database.
    """

    def __init__(self, project_root: str):
        """
        Args:
            project_root: Directory that bounds every scan (traversal guard).
        """
        super().__init__()
        self.project_root = project_root

    @property
    def name(self) -> str:
        return "npm_audit"

    @property
    def description(self) -> str:
        return """使用 npm audit 扫描 Node.js 项目的依赖漏洞。
基于 npm 官方漏洞数据库检测已知的依赖安全问题
适用于:
- 包含 package.json Node.js 项目
- 前端项目 (React, Vue, Angular )
需要先运行 npm install 安装依赖"""

    @property
    def args_schema(self):
        return NpmAuditInput

    async def _execute(
        self,
        target_path: str = ".",
        production_only: bool = False,
        **kwargs
    ) -> ToolResult:
        """Run ``npm audit --json`` in the target directory and summarise it."""
        # Path-traversal guard (previously missing here, unlike the sibling
        # scanners): the target — also used as the subprocess cwd — must stay
        # inside the project root, compared on a separator boundary.
        root = os.path.normpath(self.project_root)
        full_path = os.path.normpath(os.path.join(root, target_path))
        if full_path != root and not full_path.startswith(root + os.sep):
            return ToolResult(success=False, error="安全错误:路径越界")
        # A package.json must exist for npm audit to be meaningful.
        package_json = os.path.join(full_path, "package.json")
        if not os.path.exists(package_json):
            return ToolResult(
                success=False,
                error=f"未找到 package.json: {target_path}",
            )
        cmd = ["npm", "audit", "--json"]
        if production_only:
            cmd.append("--production")
        try:
            proc = await asyncio.create_subprocess_exec(
                *cmd,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
                cwd=full_path,
            )
            stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=120)
            try:
                results = json.loads(stdout.decode())
            except json.JSONDecodeError:
                return ToolResult(success=True, data="npm audit 输出为空或格式错误")
            vulnerabilities = results.get("vulnerabilities", {})
            if not vulnerabilities:
                return ToolResult(
                    success=True,
                    data="📦 npm audit 完成,未发现依赖漏洞",
                    metadata={"findings_count": 0}
                )
            output_parts = ["📦 npm audit 依赖漏洞扫描结果\n"]
            # Tally findings per severity before rendering the summary line.
            severity_counts = {"critical": 0, "high": 0, "moderate": 0, "low": 0}
            for name, vuln in vulnerabilities.items():
                severity = vuln.get("severity", "low")
                severity_counts[severity] = severity_counts.get(severity, 0) + 1
            output_parts.append(f"漏洞统计: 🔴 Critical: {severity_counts['critical']}, 🟠 High: {severity_counts['high']}, 🟡 Moderate: {severity_counts['moderate']}, 🟢 Low: {severity_counts['low']}\n")
            severity_icons = {"critical": "🔴", "high": "🟠", "moderate": "🟡", "low": "🟢"}
            # Report at most 20 packages; counts above still cover everything.
            for name, vuln in list(vulnerabilities.items())[:20]:
                sev = vuln.get("severity", "low")
                icon = severity_icons.get(sev, "")
                output_parts.append(f"\n{icon} [{sev.upper()}] {name}")
                output_parts.append(f" 版本范围: {vuln.get('range', 'unknown')}")
                via = vuln.get("via", [])
                if via and isinstance(via[0], dict):
                    output_parts.append(f" 来源: {via[0].get('title', '')[:100]}")
            return ToolResult(
                success=True,
                data="\n".join(output_parts),
                metadata={
                    "findings_count": len(vulnerabilities),
                    "severity_counts": severity_counts,
                }
            )
        except asyncio.TimeoutError:
            return ToolResult(success=False, error="npm audit 超时")
        except Exception as e:
            return ToolResult(success=False, error=f"npm audit 错误: {str(e)}")
# ============ Safety 工具 (Python 依赖) ============
class SafetyInput(BaseModel):
    """Input schema for a Safety Python dependency scan."""
    # Requirements file path, relative to the project root.
    requirements_file: str = Field(default="requirements.txt", description="requirements 文件路径")
class SafetyTool(AgentTool):
    """
    Safety Python dependency-vulnerability scanner.

    Checks Python dependencies against the PyUp.io vulnerability database.
    """

    def __init__(self, project_root: str):
        """
        Args:
            project_root: Directory that bounds every scan (traversal guard).
        """
        super().__init__()
        self.project_root = project_root

    @property
    def name(self) -> str:
        return "safety_scan"

    @property
    def description(self) -> str:
        return """使用 Safety 扫描 Python 依赖的安全漏洞。
基于 PyUp.io 漏洞数据库检测已知的依赖安全问题
适用于:
- 包含 requirements.txt Python 项目
- Pipenv 项目 (Pipfile.lock)
- Poetry 项目 (poetry.lock)"""

    @property
    def args_schema(self):
        return SafetyInput

    async def _execute(
        self,
        requirements_file: str = "requirements.txt",
        **kwargs
    ) -> ToolResult:
        """Run ``safety check`` against a requirements file and summarise it."""
        # Path-traversal guard (previously missing here, unlike the sibling
        # scanners): the requirements file must stay inside the project root,
        # compared on a separator boundary.
        root = os.path.normpath(self.project_root)
        full_path = os.path.normpath(os.path.join(root, requirements_file))
        if full_path != root and not full_path.startswith(root + os.sep):
            return ToolResult(success=False, error="安全错误:路径越界")
        if not os.path.exists(full_path):
            return ToolResult(success=False, error=f"未找到依赖文件: {requirements_file}")
        try:
            proc = await asyncio.create_subprocess_exec(
                "safety", "check", "-r", full_path, "--json",
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
            stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=60)
            try:
                # Safety 输出的 JSON 格式可能不同版本有差异
                output = stdout.decode()
                if "No known security" in output:
                    return ToolResult(
                        success=True,
                        data="🐍 Safety 扫描完成,未发现 Python 依赖漏洞",
                        metadata={"findings_count": 0}
                    )
                results = json.loads(output)
            except ValueError:
                # Covers json.JSONDecodeError and UnicodeDecodeError (both are
                # ValueError subclasses); fall back to the raw text.
                return ToolResult(success=True, data=f"Safety 输出:\n{stdout.decode()[:1000]}")
            vulnerabilities = results if isinstance(results, list) else results.get("vulnerabilities", [])
            if not vulnerabilities:
                return ToolResult(
                    success=True,
                    data="🐍 Safety 扫描完成,未发现 Python 依赖漏洞",
                    metadata={"findings_count": 0}
                )
            output_parts = ["🐍 Safety Python 依赖漏洞扫描结果\n"]
            output_parts.append(f"发现 {len(vulnerabilities)} 个漏洞:\n")
            for vuln in vulnerabilities[:20]:
                # Legacy Safety versions emit each finding as a positional list.
                if isinstance(vuln, list) and len(vuln) >= 4:
                    output_parts.append(f"\n🔴 {vuln[0]} ({vuln[1]})")
                    output_parts.append(f" 漏洞 ID: {vuln[4] if len(vuln) > 4 else 'N/A'}")
                    output_parts.append(f" 描述: {vuln[3][:200] if len(vuln) > 3 else ''}")
            return ToolResult(
                success=True,
                data="\n".join(output_parts),
                metadata={"findings_count": len(vulnerabilities)}
            )
        except asyncio.TimeoutError:
            # Previously folded into the generic handler with a misleading message.
            return ToolResult(success=False, error="Safety 扫描超时")
        except Exception as e:
            return ToolResult(success=False, error=f"Safety 执行错误: {str(e)}")
# ============ TruffleHog 工具 ============
class TruffleHogInput(BaseModel):
    """Input schema for a TruffleHog secret scan."""
    # Directory to scan, relative to the project root.
    target_path: str = Field(default=".", description="要扫描的目录")
    # When True, pass --only-verified so only live credentials are reported.
    only_verified: bool = Field(default=False, description="仅显示已验证的密钥")
class TruffleHogTool(AgentTool):
    """
    TruffleHog deep secret scanner.

    TruffleHog detects leaked secrets in code (and git history) and can verify
    whether a found credential is still live.
    """

    def __init__(self, project_root: str):
        """
        Args:
            project_root: Directory that bounds every scan (traversal guard).
        """
        super().__init__()
        self.project_root = project_root

    @property
    def name(self) -> str:
        return "trufflehog_scan"

    @property
    def description(self) -> str:
        return """使用 TruffleHog 进行深度密钥扫描。
TruffleHog 可以扫描代码和 Git 历史并验证密钥是否有效
特点:
- 支持 700+ 种密钥类型
- 可以验证密钥是否仍然有效
- 扫描 Git 历史记录
- 高精度低误报
建议与 Gitleaks 配合使用以获得最佳效果"""

    @property
    def args_schema(self):
        return TruffleHogInput

    async def _execute(
        self,
        target_path: str = ".",
        only_verified: bool = False,
        **kwargs
    ) -> ToolResult:
        """Run a TruffleHog filesystem scan and summarise the findings."""
        # Path-traversal guard (previously missing here, unlike the sibling
        # scanners): the target must stay inside the project root, compared on
        # a separator boundary.
        root = os.path.normpath(self.project_root)
        full_path = os.path.normpath(os.path.join(root, target_path))
        if full_path != root and not full_path.startswith(root + os.sep):
            return ToolResult(success=False, error="安全错误:路径越界")
        cmd = ["trufflehog", "filesystem", full_path, "--json"]
        if only_verified:
            cmd.append("--only-verified")
        try:
            proc = await asyncio.create_subprocess_exec(
                *cmd,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
            stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=180)
            if not stdout.strip():
                return ToolResult(
                    success=True,
                    data="🔍 TruffleHog 扫描完成,未发现密钥泄露",
                    metadata={"findings_count": 0}
                )
            # TruffleHog emits one JSON object per line (JSONL).
            findings = []
            for line in stdout.decode().strip().split('\n'):
                if line.strip():
                    try:
                        findings.append(json.loads(line))
                    except json.JSONDecodeError:
                        # Non-JSON log lines are expected; skip only those.
                        pass
            if not findings:
                return ToolResult(
                    success=True,
                    data="🔍 TruffleHog 扫描完成,未发现密钥泄露",
                    metadata={"findings_count": 0}
                )
            output_parts = ["🔍 TruffleHog 密钥扫描结果\n"]
            output_parts.append(f"⚠️ 发现 {len(findings)} 处密钥泄露!\n")
            for i, finding in enumerate(findings[:20]):
                verified = "✅ 已验证有效" if finding.get("Verified") else "⚠️ 未验证"
                output_parts.append(f"\n🔴 [{i+1}] {finding.get('DetectorName', 'unknown')} - {verified}")
                output_parts.append(f" 文件: {finding.get('SourceMetadata', {}).get('Data', {}).get('Filesystem', {}).get('file', '')}")
            return ToolResult(
                success=True,
                data="\n".join(output_parts),
                metadata={"findings_count": len(findings)}
            )
        except asyncio.TimeoutError:
            return ToolResult(success=False, error="TruffleHog 扫描超时")
        except Exception as e:
            return ToolResult(success=False, error=f"TruffleHog 执行错误: {str(e)}")
# ============ OSV-Scanner 工具 ============
class OSVScannerInput(BaseModel):
    """Input schema for an OSV-Scanner dependency scan."""
    # Project directory to scan recursively, relative to the project root.
    target_path: str = Field(default=".", description="要扫描的项目目录")
class OSVScannerTool(AgentTool):
    """
    OSV-Scanner open-source vulnerability scanner.

    Google's scanner backed by the OSV (Open Source Vulnerabilities) database;
    understands many package managers and lockfile formats.
    """

    def __init__(self, project_root: str):
        """
        Args:
            project_root: Directory that bounds every scan (traversal guard).
        """
        super().__init__()
        self.project_root = project_root

    @property
    def name(self) -> str:
        return "osv_scan"

    @property
    def description(self) -> str:
        return """使用 OSV-Scanner 扫描开源依赖漏洞。
Google 开源的漏洞扫描工具使用 OSV (Open Source Vulnerabilities) 数据库
支持:
- package.json / package-lock.json (npm)
- requirements.txt / Pipfile.lock (Python)
- go.mod / go.sum (Go)
- Cargo.lock (Rust)
- pom.xml (Maven)
- Gemfile.lock (Ruby)
- composer.lock (PHP)
特点:
- 覆盖多种语言和包管理器
- 使用 Google 维护的漏洞数据库
- 快速准确"""

    @property
    def args_schema(self):
        return OSVScannerInput

    async def _execute(
        self,
        target_path: str = ".",
        **kwargs
    ) -> ToolResult:
        """Run ``osv-scanner`` recursively over the target and summarise it."""
        # Path-traversal guard (previously missing here, unlike the sibling
        # scanners): the target must stay inside the project root, compared on
        # a separator boundary.
        root = os.path.normpath(self.project_root)
        full_path = os.path.normpath(os.path.join(root, target_path))
        if full_path != root and not full_path.startswith(root + os.sep):
            return ToolResult(success=False, error="安全错误:路径越界")
        try:
            proc = await asyncio.create_subprocess_exec(
                "osv-scanner", "--json", "-r", full_path,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
            stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=120)
            try:
                results = json.loads(stdout.decode())
            except ValueError:
                # Covers json.JSONDecodeError / decode errors on non-JSON output.
                if "no package sources found" in stdout.decode().lower():
                    return ToolResult(success=True, data="OSV-Scanner: 未找到可扫描的包文件")
                return ToolResult(success=True, data=f"OSV-Scanner 输出:\n{stdout.decode()[:1000]}")
            vulns = results.get("results", [])
            if not vulns:
                return ToolResult(
                    success=True,
                    data="📋 OSV-Scanner 扫描完成,未发现依赖漏洞",
                    metadata={"findings_count": 0}
                )
            total_vulns = sum(len(r.get("vulnerabilities", [])) for r in vulns)
            output_parts = ["📋 OSV-Scanner 开源漏洞扫描结果\n"]
            output_parts.append(f"发现 {total_vulns} 个漏洞:\n")
            # Cap the report: 10 sources, 5 vulnerabilities each.
            for result in vulns[:10]:
                source = result.get("source", {}).get("path", "unknown")
                for vuln in result.get("vulnerabilities", [])[:5]:
                    vuln_id = vuln.get("id", "")
                    summary = vuln.get("summary", "")[:100]
                    output_parts.append(f"\n🔴 {vuln_id}")
                    output_parts.append(f" 来源: {source}")
                    output_parts.append(f" 描述: {summary}")
            return ToolResult(
                success=True,
                data="\n".join(output_parts),
                metadata={"findings_count": total_vulns}
            )
        except asyncio.TimeoutError:
            # Previously folded into the generic handler; give a clear message.
            return ToolResult(success=False, error="OSV-Scanner 扫描超时")
        except Exception as e:
            return ToolResult(success=False, error=f"OSV-Scanner 执行错误: {str(e)}")
# ============ 导出所有工具 ============
# Public surface of this module. NOTE(review): SemgrepTool is not defined in
# this chunk — presumably declared earlier in the file; verify before relying
# on the export.
__all__ = [
    "SemgrepTool",
    "BanditTool",
    "GitleaksTool",
    "NpmAuditTool",
    "SafetyTool",
    "TruffleHogTool",
    "OSVScannerTool",
]

View File

@ -0,0 +1,481 @@
"""
文件操作工具
读取和搜索代码文件
"""
import os
import re
import fnmatch
from typing import Optional, List, Dict, Any
from pydantic import BaseModel, Field
from .base import AgentTool, ToolResult
class FileReadInput(BaseModel):
    """Input schema for reading a project file."""
    # Path relative to the project root.
    file_path: str = Field(description="文件路径(相对于项目根目录)")
    # Optional 1-based inclusive line window.
    start_line: Optional[int] = Field(default=None, description="起始行号从1开始")
    end_line: Optional[int] = Field(default=None, description="结束行号")
    # Cap applied when no end_line is given.
    max_lines: int = Field(default=500, description="最大返回行数")
class FileReadTool(AgentTool):
    """
    File reading tool.

    Reads a file inside the project, optionally restricted to a line range,
    and returns the content with line numbers inside a fenced code block.
    """

    def __init__(self, project_root: str):
        """
        Args:
            project_root: Directory that bounds every read (traversal guard).
        """
        super().__init__()
        self.project_root = project_root

    @property
    def name(self) -> str:
        return "read_file"

    @property
    def description(self) -> str:
        return """读取项目中的文件内容。
使用场景:
- 查看完整的源代码文件
- 查看特定行范围的代码
- 获取配置文件内容
输入:
- file_path: 文件路径相对于项目根目录
- start_line: 可选起始行号
- end_line: 可选结束行号
- max_lines: 最大返回行数默认500
注意: 为避免输出过长建议指定行范围或使用 RAG 搜索定位代码"""

    @property
    def args_schema(self):
        return FileReadInput

    async def _execute(
        self,
        file_path: str,
        start_line: Optional[int] = None,
        end_line: Optional[int] = None,
        max_lines: int = 500,
        **kwargs
    ) -> ToolResult:
        """Read (part of) a project file and return it with line numbers."""
        try:
            # Path-traversal guard anchored at a separator boundary. A plain
            # startswith() prefix test would wrongly accept sibling paths
            # (root "/proj" matching "/proj2/secret").
            root = os.path.normpath(self.project_root)
            full_path = os.path.normpath(os.path.join(root, file_path))
            if full_path != root and not full_path.startswith(root + os.sep):
                return ToolResult(
                    success=False,
                    error="安全错误:不允许访问项目目录外的文件",
                )
            if not os.path.exists(full_path):
                return ToolResult(
                    success=False,
                    error=f"文件不存在: {file_path}",
                )
            if not os.path.isfile(full_path):
                return ToolResult(
                    success=False,
                    error=f"不是文件: {file_path}",
                )
            # Refuse very large files instead of flooding the agent context.
            file_size = os.path.getsize(full_path)
            if file_size > 1024 * 1024:  # 1MB
                return ToolResult(
                    success=False,
                    error=f"文件过大 ({file_size / 1024:.1f}KB),请指定行范围",
                )
            with open(full_path, 'r', encoding='utf-8', errors='ignore') as f:
                lines = f.readlines()
            total_lines = len(lines)
            # Resolve the requested (1-based, inclusive) window into slice indices.
            start_idx = max(0, start_line - 1) if start_line is not None else 0
            if end_line is not None:
                end_idx = min(total_lines, end_line)
            else:
                end_idx = min(total_lines, start_idx + max_lines)
            selected_lines = lines[start_idx:end_idx]
            # Prefix each line with its 1-based number for easy referencing.
            numbered_lines = []
            for i, line in enumerate(selected_lines, start=start_idx + 1):
                numbered_lines.append(f"{i:4d}| {line.rstrip()}")
            content = '\n'.join(numbered_lines)
            # Pick a syntax-highlight tag for the fenced block from the extension.
            ext = os.path.splitext(file_path)[1].lower()
            language = {
                ".py": "python", ".js": "javascript", ".ts": "typescript",
                ".java": "java", ".go": "go", ".rs": "rust",
                ".cpp": "cpp", ".c": "c", ".cs": "csharp",
                ".php": "php", ".rb": "ruby", ".swift": "swift",
            }.get(ext, "text")
            output = f"📄 文件: {file_path}\n"
            output += f"行数: {start_idx + 1}-{end_idx} / {total_lines}\n\n"
            output += f"```{language}\n{content}\n```"
            if end_idx < total_lines:
                output += f"\n\n... 还有 {total_lines - end_idx} 行未显示"
            return ToolResult(
                success=True,
                data=output,
                metadata={
                    "file_path": file_path,
                    "total_lines": total_lines,
                    "start_line": start_idx + 1,
                    "end_line": end_idx,
                    "language": language,
                }
            )
        except Exception as e:
            return ToolResult(
                success=False,
                error=f"读取文件失败: {str(e)}",
            )
class FileSearchInput(BaseModel):
    """Input schema for the project-wide code search."""
    keyword: str = Field(description="搜索关键字或正则表达式")
    # Filename glob, e.g. "*.py"; None searches every file.
    file_pattern: Optional[str] = Field(default=None, description="文件名模式,如 *.py, *.js")
    # Sub-directory to search, relative to the project root; None means the whole project.
    directory: Optional[str] = Field(default=None, description="搜索目录(相对路径)")
    case_sensitive: bool = Field(default=False, description="是否区分大小写")
    max_results: int = Field(default=50, description="最大结果数")
    # When False the keyword is regex-escaped and matched literally.
    is_regex: bool = Field(default=False, description="是否使用正则表达式")
class FileSearchTool(AgentTool):
    """
    Code search tool.

    Greps project files for a keyword or regex and returns matches with a
    small context window around each hit.
    """

    # Directories that are never descended into.
    EXCLUDE_DIRS = {
        "node_modules", "vendor", "dist", "build", ".git",
        "__pycache__", ".pytest_cache", "coverage", ".nyc_output",
        ".vscode", ".idea", ".vs", "target", "venv", "env",
    }

    def __init__(self, project_root: str):
        """
        Args:
            project_root: Directory that bounds every search (traversal guard).
        """
        super().__init__()
        self.project_root = project_root

    @property
    def name(self) -> str:
        return "search_code"

    @property
    def description(self) -> str:
        return """在项目代码中搜索关键字或模式。
使用场景:
- 查找特定函数的所有调用位置
- 搜索特定的 API 使用
- 查找包含特定模式的代码
输入:
- keyword: 搜索关键字或正则表达式
- file_pattern: 可选文件名模式 *.py
- directory: 可选搜索目录
- case_sensitive: 是否区分大小写
- is_regex: 是否使用正则表达式
这是一个快速搜索工具结果包含匹配行和上下文"""

    @property
    def args_schema(self):
        return FileSearchInput

    async def _execute(
        self,
        keyword: str,
        file_pattern: Optional[str] = None,
        directory: Optional[str] = None,
        case_sensitive: bool = False,
        max_results: int = 50,
        is_regex: bool = False,
        **kwargs
    ) -> ToolResult:
        """Search project files for ``keyword`` and return matches with context."""
        try:
            # Resolve the search directory, keeping it inside the project root.
            # The containment check is anchored at a separator boundary; a bare
            # startswith() prefix test would accept "/proj2" for root "/proj".
            root = os.path.normpath(self.project_root)
            if directory:
                search_dir = os.path.normpath(os.path.join(root, directory))
                if search_dir != root and not search_dir.startswith(root + os.sep):
                    return ToolResult(
                        success=False,
                        error="安全错误:不允许搜索项目目录外的内容",
                    )
            else:
                search_dir = self.project_root
            # Compile once; literal keywords are escaped so they match verbatim.
            flags = 0 if case_sensitive else re.IGNORECASE
            try:
                pattern = re.compile(keyword if is_regex else re.escape(keyword), flags)
            except re.error as e:
                return ToolResult(
                    success=False,
                    error=f"无效的搜索模式: {e}",
                )
            results = []
            files_searched = 0
            for walk_root, dirs, files in os.walk(search_dir):
                # Prune excluded directories in place so os.walk skips them.
                dirs[:] = [d for d in dirs if d not in self.EXCLUDE_DIRS]
                for filename in files:
                    if file_pattern and not fnmatch.fnmatch(filename, file_pattern):
                        continue
                    file_path = os.path.join(walk_root, filename)
                    relative_path = os.path.relpath(file_path, self.project_root)
                    try:
                        with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
                            lines = f.readlines()
                        files_searched += 1
                        for i, line in enumerate(lines):
                            if pattern.search(line):
                                # One line of context either side of the hit.
                                start = max(0, i - 1)
                                end = min(len(lines), i + 2)
                                context_lines = []
                                for j in range(start, end):
                                    prefix = ">" if j == i else " "
                                    context_lines.append(f"{prefix} {j+1:4d}| {lines[j].rstrip()}")
                                results.append({
                                    "file": relative_path,
                                    "line": i + 1,
                                    "match": line.strip()[:200],
                                    "context": '\n'.join(context_lines),
                                })
                                if len(results) >= max_results:
                                    break
                        if len(results) >= max_results:
                            break
                    except Exception:
                        # Unreadable files are skipped on purpose (best-effort grep).
                        continue
                if len(results) >= max_results:
                    break
            if not results:
                return ToolResult(
                    success=True,
                    data=f"没有找到匹配 '{keyword}' 的内容\n搜索了 {files_searched} 个文件",
                    metadata={"files_searched": files_searched, "matches": 0}
                )
            output_parts = [f"🔍 搜索结果: '{keyword}'\n"]
            output_parts.append(f"找到 {len(results)} 处匹配(搜索了 {files_searched} 个文件)\n")
            for result in results:
                output_parts.append(f"\n📄 {result['file']}:{result['line']}")
                output_parts.append(f"```\n{result['context']}\n```")
            if len(results) >= max_results:
                output_parts.append(f"\n... 结果已截断(最大 {max_results} 条)")
            return ToolResult(
                success=True,
                data="\n".join(output_parts),
                metadata={
                    "keyword": keyword,
                    "files_searched": files_searched,
                    "matches": len(results),
                    "results": results[:10],  # keep only the first 10 in metadata
                }
            )
        except Exception as e:
            return ToolResult(
                success=False,
                error=f"搜索失败: {str(e)}",
            )
class ListFilesInput(BaseModel):
    """Input schema for listing directory contents."""
    directory: str = Field(default=".", description="目录路径(相对于项目根目录)")
    # Filename glob, e.g. "*.py"; None lists every file.
    pattern: Optional[str] = Field(default=None, description="文件名模式,如 *.py")
    recursive: bool = Field(default=False, description="是否递归列出子目录")
    max_files: int = Field(default=100, description="最大文件数")
class ListFilesTool(AgentTool):
    """
    Directory listing tool.

    Lists files (and, in non-recursive mode, subdirectories) under a project path.
    """

    # Directories that are never listed or descended into.
    EXCLUDE_DIRS = {
        "node_modules", "vendor", "dist", "build", ".git",
        "__pycache__", ".pytest_cache", "coverage",
    }

    def __init__(self, project_root: str):
        """
        Args:
            project_root: Directory that bounds every listing (traversal guard).
        """
        super().__init__()
        self.project_root = project_root

    @property
    def name(self) -> str:
        return "list_files"

    @property
    def description(self) -> str:
        return """列出目录中的文件。
使用场景:
- 了解项目结构
- 查找特定类型的文件
- 浏览目录内容
输入:
- directory: 目录路径
- pattern: 可选文件名模式
- recursive: 是否递归
- max_files: 最大文件数"""

    @property
    def args_schema(self):
        return ListFilesInput

    async def _execute(
        self,
        directory: str = ".",
        pattern: Optional[str] = None,
        recursive: bool = False,
        max_files: int = 100,
        **kwargs
    ) -> ToolResult:
        """List directory contents, optionally recursive and pattern-filtered."""
        try:
            # Path-traversal guard anchored at a separator boundary. A plain
            # startswith() prefix test would wrongly accept sibling directories
            # (root "/proj" matching "/proj2").
            root = os.path.normpath(self.project_root)
            target_dir = os.path.normpath(os.path.join(root, directory))
            if target_dir != root and not target_dir.startswith(root + os.sep):
                return ToolResult(
                    success=False,
                    error="安全错误:不允许访问项目目录外的目录",
                )
            if not os.path.exists(target_dir):
                return ToolResult(
                    success=False,
                    error=f"目录不存在: {directory}",
                )
            files = []
            dirs = []
            if recursive:
                for walk_root, dirnames, filenames in os.walk(target_dir):
                    # Prune excluded directories in place so os.walk skips them.
                    dirnames[:] = [d for d in dirnames if d not in self.EXCLUDE_DIRS]
                    for filename in filenames:
                        if pattern and not fnmatch.fnmatch(filename, pattern):
                            continue
                        full_path = os.path.join(walk_root, filename)
                        relative_path = os.path.relpath(full_path, self.project_root)
                        files.append(relative_path)
                        if len(files) >= max_files:
                            break
                    if len(files) >= max_files:
                        break
            else:
                for item in os.listdir(target_dir):
                    if item in self.EXCLUDE_DIRS:
                        continue
                    full_path = os.path.join(target_dir, item)
                    relative_path = os.path.relpath(full_path, self.project_root)
                    if os.path.isdir(full_path):
                        dirs.append(relative_path + "/")
                    else:
                        if pattern and not fnmatch.fnmatch(item, pattern):
                            continue
                        files.append(relative_path)
                        if len(files) >= max_files:
                            break
            output_parts = [f"📁 目录: {directory}\n"]
            if dirs:
                output_parts.append("目录:")
                for d in sorted(dirs)[:20]:
                    output_parts.append(f" 📂 {d}")
                if len(dirs) > 20:
                    output_parts.append(f" ... 还有 {len(dirs) - 20} 个目录")
            if files:
                output_parts.append(f"\n文件 ({len(files)}):")
                for f in sorted(files):
                    output_parts.append(f" 📄 {f}")
            if len(files) >= max_files:
                output_parts.append(f"\n... 结果已截断(最大 {max_files} 个文件)")
            return ToolResult(
                success=True,
                data="\n".join(output_parts),
                metadata={
                    "directory": directory,
                    "file_count": len(files),
                    "dir_count": len(dirs),
                }
            )
        except Exception as e:
            return ToolResult(
                success=False,
                error=f"列出文件失败: {str(e)}",
            )

View File

@ -0,0 +1,418 @@
"""
模式匹配工具
快速扫描代码中的危险模式
"""
import re
from typing import Optional, List, Dict, Any
from pydantic import BaseModel, Field
from dataclasses import dataclass
from .base import AgentTool, ToolResult
@dataclass
class PatternMatch:
    """A single regex hit produced by the pattern-matching tool."""
    pattern_name: str    # human label of the matched pattern (e.g. "eval()")
    pattern_type: str    # vulnerability category key, e.g. "sql_injection"
    file_path: str       # path of the scanned file, as supplied by the caller
    line_number: int     # 1-based line of the match
    matched_text: str    # matched line, stripped and capped at 200 chars
    context: str         # surrounding lines rendered with line numbers
    severity: str        # category severity: "critical"/"high"/"medium"/"low"
    description: str     # category-level description of the issue
class PatternMatchInput(BaseModel):
    """Input schema for the regex-based dangerous-pattern scan."""
    # Raw source code to scan (not a file path).
    code: str = Field(description="要扫描的代码内容")
    # Used for reporting and for language auto-detection from the extension.
    file_path: str = Field(default="unknown", description="文件路径")
    pattern_types: Optional[List[str]] = Field(
        default=None,
        description="要检测的漏洞类型列表,如 ['sql_injection', 'xss']。为空则检测所有类型"
    )
    # Overrides language auto-detection when provided.
    language: Optional[str] = Field(default=None, description="编程语言,用于选择特定模式")
class PatternMatchTool(AgentTool):
"""
模式匹配工具
使用正则表达式快速扫描代码中的危险模式
"""
# 危险模式定义
PATTERNS: Dict[str, Dict[str, Any]] = {
# SQL 注入模式
"sql_injection": {
"patterns": {
"python": [
(r'cursor\.execute\s*\(\s*["\'].*%[sd].*["\'].*%', "格式化字符串构造SQL"),
(r'cursor\.execute\s*\(\s*f["\']', "f-string构造SQL"),
(r'cursor\.execute\s*\([^,)]+\+', "字符串拼接构造SQL"),
(r'\.execute\s*\(\s*["\'][^"\']*\{', "format()构造SQL"),
(r'text\s*\(\s*["\'].*\+.*["\']', "SQLAlchemy text()拼接"),
],
"javascript": [
(r'\.query\s*\(\s*[`"\'].*\$\{', "模板字符串构造SQL"),
(r'\.query\s*\(\s*["\'].*\+', "字符串拼接构造SQL"),
(r'mysql\.query\s*\([^,)]+\+', "MySQL查询拼接"),
],
"java": [
(r'Statement.*execute.*\+', "Statement字符串拼接"),
(r'createQuery\s*\([^,)]+\+', "JPA查询拼接"),
(r'\.executeQuery\s*\([^,)]+\+', "executeQuery拼接"),
],
"php": [
(r'mysql_query\s*\(\s*["\'].*\.\s*\$', "mysql_query拼接"),
(r'mysqli_query\s*\([^,]+,\s*["\'].*\.\s*\$', "mysqli_query拼接"),
(r'\$pdo->query\s*\(\s*["\'].*\.\s*\$', "PDO query拼接"),
],
"go": [
(r'\.Query\s*\([^,)]+\+', "Query字符串拼接"),
(r'\.Exec\s*\([^,)]+\+', "Exec字符串拼接"),
(r'fmt\.Sprintf\s*\([^)]+\)\s*\)', "Sprintf构造SQL"),
],
},
"severity": "high",
"description": "SQL注入漏洞用户输入直接拼接到SQL语句中",
},
# XSS 模式
"xss": {
"patterns": {
"javascript": [
(r'innerHTML\s*=\s*[^;]+', "innerHTML赋值"),
(r'outerHTML\s*=\s*[^;]+', "outerHTML赋值"),
(r'document\.write\s*\(', "document.write"),
(r'\.html\s*\([^)]+\)', "jQuery html()"),
(r'dangerouslySetInnerHTML', "React dangerouslySetInnerHTML"),
],
"python": [
(r'\|\s*safe\b', "Django safe过滤器"),
(r'Markup\s*\(', "Flask Markup"),
(r'mark_safe\s*\(', "Django mark_safe"),
],
"php": [
(r'echo\s+\$_(?:GET|POST|REQUEST)', "直接输出用户输入"),
(r'print\s+\$_(?:GET|POST|REQUEST)', "打印用户输入"),
],
"java": [
(r'out\.print(?:ln)?\s*\([^)]*request\.getParameter', "直接输出请求参数"),
],
},
"severity": "high",
"description": "XSS跨站脚本漏洞未转义的用户输入被渲染到页面",
},
# 命令注入模式
"command_injection": {
"patterns": {
"python": [
(r'os\.system\s*\([^)]*\+', "os.system拼接"),
(r'os\.system\s*\([^)]*%', "os.system格式化"),
(r'os\.system\s*\(\s*f["\']', "os.system f-string"),
(r'subprocess\.(?:call|run|Popen)\s*\([^)]*shell\s*=\s*True', "shell=True"),
(r'subprocess\.(?:call|run|Popen)\s*\(\s*["\'][^"\']+%', "subprocess格式化"),
(r'eval\s*\(', "eval()"),
(r'exec\s*\(', "exec()"),
],
"javascript": [
(r'exec\s*\([^)]+\+', "exec拼接"),
(r'spawn\s*\([^)]+,\s*\{[^}]*shell:\s*true', "spawn shell"),
(r'eval\s*\(', "eval()"),
(r'Function\s*\(', "Function构造器"),
],
"php": [
(r'exec\s*\(\s*\$', "exec变量"),
(r'system\s*\(\s*\$', "system变量"),
(r'passthru\s*\(\s*\$', "passthru变量"),
(r'shell_exec\s*\(\s*\$', "shell_exec变量"),
(r'`[^`]*\$[^`]*`', "反引号命令执行"),
],
"java": [
(r'Runtime\.getRuntime\(\)\.exec\s*\([^)]+\+', "Runtime.exec拼接"),
(r'ProcessBuilder[^;]+\+', "ProcessBuilder拼接"),
],
"go": [
(r'exec\.Command\s*\([^)]+\+', "exec.Command拼接"),
],
},
"severity": "critical",
"description": "命令注入漏洞:用户输入被用于执行系统命令",
},
# 路径遍历模式
"path_traversal": {
"patterns": {
"python": [
(r'open\s*\([^)]*\+', "open()拼接"),
(r'open\s*\([^)]*%', "open()格式化"),
(r'os\.path\.join\s*\([^)]*request', "join用户输入"),
(r'send_file\s*\([^)]*request', "send_file用户输入"),
],
"javascript": [
(r'fs\.read(?:File|FileSync)\s*\([^)]+\+', "readFile拼接"),
(r'path\.join\s*\([^)]*req\.', "path.join用户输入"),
(r'res\.sendFile\s*\([^)]+\+', "sendFile拼接"),
],
"php": [
(r'include\s*\(\s*\$', "include变量"),
(r'require\s*\(\s*\$', "require变量"),
(r'file_get_contents\s*\(\s*\$', "file_get_contents变量"),
(r'fopen\s*\(\s*\$', "fopen变量"),
],
"java": [
(r'new\s+File\s*\([^)]+request\.getParameter', "File构造用户输入"),
(r'new\s+FileInputStream\s*\([^)]+\+', "FileInputStream拼接"),
],
},
"severity": "high",
"description": "路径遍历漏洞:用户可以访问任意文件",
},
# SSRF 模式
"ssrf": {
"patterns": {
"python": [
(r'requests\.(?:get|post|put|delete)\s*\([^)]*request\.', "requests用户URL"),
(r'urllib\.request\.urlopen\s*\([^)]*request\.', "urlopen用户URL"),
(r'httpx\.(?:get|post)\s*\([^)]*request\.', "httpx用户URL"),
],
"javascript": [
(r'fetch\s*\([^)]*req\.', "fetch用户URL"),
(r'axios\.(?:get|post)\s*\([^)]*req\.', "axios用户URL"),
(r'http\.request\s*\([^)]*req\.', "http.request用户URL"),
],
"java": [
(r'new\s+URL\s*\([^)]*request\.getParameter', "URL构造用户输入"),
(r'HttpClient[^;]+request\.getParameter', "HttpClient用户URL"),
],
"php": [
(r'curl_setopt[^;]+CURLOPT_URL[^;]+\$', "curl用户URL"),
(r'file_get_contents\s*\(\s*\$_', "file_get_contents用户URL"),
],
},
"severity": "high",
"description": "SSRF漏洞服务端请求用户控制的URL",
},
# 不安全的反序列化
"deserialization": {
"patterns": {
"python": [
(r'pickle\.loads?\s*\(', "pickle反序列化"),
(r'yaml\.load\s*\([^)]*(?!Loader)', "yaml.load无安全Loader"),
(r'yaml\.unsafe_load\s*\(', "yaml.unsafe_load"),
(r'marshal\.loads?\s*\(', "marshal反序列化"),
],
"javascript": [
(r'serialize\s*\(', "serialize"),
(r'unserialize\s*\(', "unserialize"),
],
"java": [
(r'ObjectInputStream\s*\(', "ObjectInputStream"),
(r'XMLDecoder\s*\(', "XMLDecoder"),
(r'readObject\s*\(', "readObject"),
],
"php": [
(r'unserialize\s*\(\s*\$', "unserialize用户输入"),
],
},
"severity": "critical",
"description": "不安全的反序列化:可能导致远程代码执行",
},
# 硬编码密钥
"hardcoded_secret": {
"patterns": {
"_common": [
(r'(?:password|passwd|pwd)\s*=\s*["\'][^"\']{4,}["\']', "硬编码密码"),
(r'(?:secret|api_?key|apikey|token|auth)\s*=\s*["\'][^"\']{8,}["\']', "硬编码密钥"),
(r'(?:private_?key|priv_?key)\s*=\s*["\'][^"\']+["\']', "硬编码私钥"),
(r'-----BEGIN\s+(?:RSA\s+)?PRIVATE\s+KEY-----', "私钥"),
(r'(?:aws_?access_?key|aws_?secret)\s*=\s*["\'][^"\']+["\']', "AWS密钥"),
(r'(?:ghp_|gho_|github_pat_)[a-zA-Z0-9]{36,}', "GitHub Token"),
(r'sk-[a-zA-Z0-9]{48}', "OpenAI API Key"),
(r'(?:bearer|authorization)\s*[=:]\s*["\'][^"\']{20,}["\']', "Bearer Token"),
],
},
"severity": "medium",
"description": "硬编码密钥:敏感信息不应该硬编码在代码中",
},
# 弱加密
"weak_crypto": {
"patterns": {
"python": [
(r'hashlib\.md5\s*\(', "MD5哈希"),
(r'hashlib\.sha1\s*\(', "SHA1哈希"),
(r'DES\s*\(', "DES加密"),
(r'random\.random\s*\(', "不安全随机数"),
],
"javascript": [
(r'crypto\.createHash\s*\(\s*["\']md5["\']', "MD5哈希"),
(r'crypto\.createHash\s*\(\s*["\']sha1["\']', "SHA1哈希"),
(r'Math\.random\s*\(', "Math.random"),
],
"java": [
(r'MessageDigest\.getInstance\s*\(\s*["\']MD5["\']', "MD5哈希"),
(r'MessageDigest\.getInstance\s*\(\s*["\']SHA-?1["\']', "SHA1哈希"),
(r'DESKeySpec', "DES密钥"),
],
"php": [
(r'md5\s*\(', "MD5哈希"),
(r'sha1\s*\(', "SHA1哈希"),
(r'mcrypt_', "mcrypt已废弃"),
],
},
"severity": "low",
"description": "弱加密算法:使用了不安全的加密或哈希算法",
},
}
    @property
    def name(self) -> str:
        # Tool identifier used for agent tool registration/dispatch.
        return "pattern_match"
@property
def description(self) -> str:
vuln_types = ", ".join(self.PATTERNS.keys())
return f"""快速扫描代码中的危险模式和常见漏洞。
使用正则表达式检测已知的不安全代码模式
支持的漏洞类型: {vuln_types}
这是一个快速扫描工具可以在分析开始时使用来快速发现潜在问题
发现的问题需要进一步分析确认"""
    @property
    def args_schema(self):
        # Pydantic schema describing this tool's accepted arguments.
        return PatternMatchInput
async def _execute(
    self,
    code: str,
    file_path: str = "unknown",
    pattern_types: Optional[List[str]] = None,
    language: Optional[str] = None,
    **kwargs
) -> ToolResult:
    """Scan ``code`` line by line against the known dangerous-pattern regexes.

    Args:
        code: Source text to scan.
        file_path: Path used in the report and for language auto-detection.
        pattern_types: Subset of PATTERNS keys to check; all types when None.
        language: Language key into each pattern table; auto-detected from
            the file extension when not supplied.

    Returns:
        ToolResult whose ``data`` is a human-readable report (severity-sorted)
        and whose ``metadata`` carries structured match details.
    """
    matches: List[PatternMatch] = []
    lines = code.split('\n')
    # Which vulnerability categories to check (default: all known types).
    types_to_check = pattern_types or list(self.PATTERNS.keys())
    # Auto-detect the language from the file extension if not supplied.
    if not language:
        language = self._detect_language(file_path)
    for vuln_type in types_to_check:
        if vuln_type not in self.PATTERNS:
            continue
        pattern_config = self.PATTERNS[vuln_type]
        patterns_dict = pattern_config["patterns"]
        # Collect language-specific patterns plus the shared "_common" set.
        patterns_to_use = []
        if language and language in patterns_dict:
            patterns_to_use.extend(patterns_dict[language])
        if "_common" in patterns_dict:
            patterns_to_use.extend(patterns_dict["_common"])
        # No patterns for this language: fall back to every language's
        # patterns so unknown file types still get some coverage.
        if not patterns_to_use:
            for lang, pats in patterns_dict.items():
                if lang != "_common":
                    patterns_to_use.extend(pats)
        # Run each regex against every line, case-insensitively.
        for pattern, pattern_name in patterns_to_use:
            try:
                for i, line in enumerate(lines):
                    if re.search(pattern, line, re.IGNORECASE):
                        # Capture two lines of context on either side of the hit.
                        start = max(0, i - 2)
                        end = min(len(lines), i + 3)
                        context = '\n'.join(f"{j+1}: {lines[j]}" for j in range(start, end))
                        matches.append(PatternMatch(
                            pattern_name=pattern_name,
                            pattern_type=vuln_type,
                            file_path=file_path,
                            line_number=i + 1,
                            matched_text=line.strip()[:200],
                            context=context,
                            severity=pattern_config["severity"],
                            description=pattern_config["description"],
                        ))
            except re.error:
                # Skip a malformed regex rather than aborting the whole scan.
                continue
    if not matches:
        return ToolResult(
            success=True,
            data="没有检测到已知的危险模式",
            metadata={"patterns_checked": len(types_to_check), "matches": 0}
        )
    # Build the human-readable report.
    output_parts = [f"⚠️ 检测到 {len(matches)} 个潜在问题:\n"]
    # Most severe findings first; unknown severities sort last.
    severity_order = {"critical": 0, "high": 1, "medium": 2, "low": 3}
    matches.sort(key=lambda x: severity_order.get(x.severity, 4))
    for match in matches:
        severity_icon = {"critical": "🔴", "high": "🟠", "medium": "🟡", "low": "🟢"}.get(match.severity, "")
        output_parts.append(f"\n{severity_icon} [{match.severity.upper()}] {match.pattern_type}")
        output_parts.append(f" 位置: {match.file_path}:{match.line_number}")
        output_parts.append(f" 模式: {match.pattern_name}")
        output_parts.append(f" 描述: {match.description}")
        output_parts.append(f" 匹配: {match.matched_text}")
        output_parts.append(f" 上下文:\n{match.context}")
    return ToolResult(
        success=True,
        data="\n".join(output_parts),
        metadata={
            "matches": len(matches),
            "by_severity": {
                s: len([m for m in matches if m.severity == s])
                for s in ["critical", "high", "medium", "low"]
            },
            "details": [
                {
                    "type": m.pattern_type,
                    "severity": m.severity,
                    "line": m.line_number,
                    "pattern": m.pattern_name,
                }
                for m in matches
            ]
        }
    )
def _detect_language(self, file_path: str) -> Optional[str]:
"""根据文件扩展名检测语言"""
ext_map = {
".py": "python",
".js": "javascript",
".jsx": "javascript",
".ts": "javascript",
".tsx": "javascript",
".java": "java",
".php": "php",
".go": "go",
".rb": "ruby",
}
for ext, lang in ext_map.items():
if file_path.lower().endswith(ext):
return lang
return None

View File

@ -0,0 +1,293 @@
"""
RAG 检索工具
支持语义检索代码
"""
from typing import Optional, List
from pydantic import BaseModel, Field
from .base import AgentTool, ToolResult
from app.services.rag import CodeRetriever
class RAGQueryInput(BaseModel):
    """Input schema for the RAG semantic-search tool."""
    # Natural-language description of the code being searched for.
    query: str = Field(description="搜索查询,描述你要找的代码功能或特征")
    # Maximum number of hits to return.
    top_k: int = Field(default=10, description="返回结果数量")
    # Optional filter: restrict the search to one file.
    file_path: Optional[str] = Field(default=None, description="限定搜索的文件路径")
    # Optional filter: restrict the search to one programming language.
    language: Optional[str] = Field(default=None, description="限定编程语言")
class RAGQueryTool(AgentTool):
    """Semantic code search over the indexed codebase.

    Delegates to CodeRetriever.retrieve() and formats each hit with its
    location, optional symbol name, security indicators and source text.
    """

    def __init__(self, retriever: CodeRetriever):
        super().__init__()
        self.retriever = retriever

    @property
    def name(self) -> str:
        return "rag_query"

    @property
    def description(self) -> str:
        return """在代码库中进行语义搜索。
使用场景:
- 查找特定功能的实现代码
- 查找调用某个函数的代码
- 查找处理用户输入的代码
- 查找数据库操作相关代码
- 查找认证/授权相关代码
输入:
- query: 描述你要查找的代码例如 "处理用户登录的函数""SQL查询执行""文件上传处理"
- top_k: 返回结果数量默认10
- file_path: 可选限定在某个文件中搜索
- language: 可选限定编程语言
输出: 相关的代码片段列表包含文件路径行号代码内容和相似度分数"""

    @property
    def args_schema(self):
        return RAGQueryInput

    async def _execute(
        self,
        query: str,
        top_k: int = 10,
        file_path: Optional[str] = None,
        language: Optional[str] = None,
        **kwargs
    ) -> ToolResult:
        """Run the semantic query and render the hits as a readable report."""
        try:
            results = await self.retriever.retrieve(
                query=query,
                top_k=top_k,
                filter_file_path=file_path,
                filter_language=language,
            )
            if not results:
                # An empty result set is still a successful search.
                return ToolResult(
                    success=True,
                    data="没有找到相关代码",
                    metadata={"query": query, "results_count": 0}
                )
            report = [f"找到 {len(results)} 个相关代码片段:\n"]
            for i, result in enumerate(results):
                section = [
                    f"\n--- 结果 {i+1} (相似度: {result.score:.2f}) ---",
                    f"文件: {result.file_path}",
                    f"行号: {result.line_start}-{result.line_end}",
                ]
                if result.name:
                    section.append(f"名称: {result.name}")
                if result.security_indicators:
                    section.append(f"安全指标: {', '.join(result.security_indicators)}")
                section.append(f"代码:\n```{result.language}\n{result.content}\n```")
                report.extend(section)
            return ToolResult(
                success=True,
                data="\n".join(report),
                metadata={
                    "query": query,
                    "results_count": len(results),
                    "results": [r.to_dict() for r in results],
                }
            )
        except Exception as e:
            return ToolResult(
                success=False,
                error=f"RAG 检索失败: {str(e)}",
            )
class SecurityCodeSearchInput(BaseModel):
    """Input schema for the security-oriented code search tool."""
    # One of the supported vulnerability categories (see description).
    vulnerability_type: str = Field(
        description="漏洞类型: sql_injection, xss, command_injection, path_traversal, ssrf, deserialization, auth_bypass, hardcoded_secret"
    )
    # Maximum number of suspicious snippets to return.
    top_k: int = Field(default=20, description="返回结果数量")
class SecurityCodeSearchTool(AgentTool):
    """Vulnerability-oriented code search.

    Thin wrapper around CodeRetriever.retrieve_security_related() that
    renders suspicious snippets (with their security indicators) into a
    readable report for the agent.
    """

    def __init__(self, retriever: CodeRetriever):
        super().__init__()
        self.retriever = retriever

    @property
    def name(self) -> str:
        return "security_code_search"

    @property
    def description(self) -> str:
        return """搜索可能存在安全漏洞的代码。
专门针对特定漏洞类型进行搜索
支持的漏洞类型:
- sql_injection: SQL 注入
- xss: 跨站脚本
- command_injection: 命令注入
- path_traversal: 路径遍历
- ssrf: 服务端请求伪造
- deserialization: 不安全的反序列化
- auth_bypass: 认证绕过
- hardcoded_secret: 硬编码密钥"""

    @property
    def args_schema(self):
        return SecurityCodeSearchInput

    async def _execute(
        self,
        vulnerability_type: str,
        top_k: int = 20,
        **kwargs
    ) -> ToolResult:
        """Search for code plausibly affected by ``vulnerability_type``."""
        try:
            results = await self.retriever.retrieve_security_related(
                vulnerability_type=vulnerability_type,
                top_k=top_k,
            )
            if not results:
                # Nothing matched: still a successful (empty) search.
                return ToolResult(
                    success=True,
                    data=f"没有找到与 {vulnerability_type} 相关的代码",
                    metadata={"vulnerability_type": vulnerability_type, "results_count": 0}
                )
            report = [f"找到 {len(results)} 个可能与 {vulnerability_type} 相关的代码:\n"]
            for i, result in enumerate(results):
                section = [
                    f"\n--- 可疑代码 {i+1} ---",
                    f"文件: {result.file_path}:{result.line_start}",
                ]
                if result.security_indicators:
                    section.append(f"⚠️ 安全指标: {', '.join(result.security_indicators)}")
                section.append(f"代码:\n```{result.language}\n{result.content}\n```")
                report.extend(section)
            return ToolResult(
                success=True,
                data="\n".join(report),
                metadata={
                    "vulnerability_type": vulnerability_type,
                    "results_count": len(results),
                }
            )
        except Exception as e:
            return ToolResult(
                success=False,
                error=f"安全代码搜索失败: {str(e)}",
            )
class FunctionContextInput(BaseModel):
    """Input schema for the function-context lookup tool."""
    # Name of the function whose context is requested.
    function_name: str = Field(description="函数名称")
    # Optional filter: restrict the lookup to one file.
    file_path: Optional[str] = Field(default=None, description="文件路径")
    # Whether to also return code that calls the function.
    include_callers: bool = Field(default=True, description="是否包含调用者")
    # Whether to also return functions called by it.
    include_callees: bool = Field(default=True, description="是否包含被调用的函数")
class FunctionContextTool(AgentTool):
    """
    Function-context lookup tool.

    Finds a function's definition, its callers and its callees via the
    retriever, for tracing data flow through the codebase.
    """
    def __init__(self, retriever: CodeRetriever):
        super().__init__()
        self.retriever = retriever
    @property
    def name(self) -> str:
        # Stable tool identifier used by the agent framework.
        return "function_context"
    @property
    def description(self) -> str:
        return """查找函数的上下文信息,包括定义、调用者和被调用的函数。
用于追踪数据流和理解函数的使用方式
输入:
- function_name: 要查找的函数名
- file_path: 可选限定文件路径
- include_callers: 是否查找调用此函数的代码
- include_callees: 是否查找此函数调用的其他函数"""
    @property
    def args_schema(self):
        # Pydantic model describing the accepted arguments.
        return FunctionContextInput
    async def _execute(
        self,
        function_name: str,
        file_path: Optional[str] = None,
        include_callers: bool = True,
        include_callees: bool = True,
        **kwargs
    ) -> ToolResult:
        """Look up definition/callers/callees and render them as a report.

        Relies on the retriever returning a dict with "definition",
        "callers" and "callees" lists; caller/callee output is capped at
        5 entries each to keep the report short.
        """
        try:
            context = await self.retriever.retrieve_function_context(
                function_name=function_name,
                file_path=file_path,
                include_callers=include_callers,
                include_callees=include_callees,
            )
            output_parts = [f"函数 '{function_name}' 的上下文分析:\n"]
            # Definition section (may legitimately be empty).
            if context["definition"]:
                output_parts.append("### 函数定义:")
                for result in context["definition"]:
                    output_parts.append(f"文件: {result.file_path}:{result.line_start}")
                    output_parts.append(f"```{result.language}\n{result.content}\n```")
            else:
                output_parts.append("未找到函数定义")
            # Call sites (at most 5 shown, bodies truncated to 500 chars).
            if context["callers"]:
                output_parts.append(f"\n### 调用此函数的代码 ({len(context['callers'])} 处):")
                for result in context["callers"][:5]:
                    output_parts.append(f"- {result.file_path}:{result.line_start}")
                    output_parts.append(f"```{result.language}\n{result.content[:500]}\n```")
            # Functions this one calls (at most 5 shown, names only).
            if context["callees"]:
                output_parts.append(f"\n### 此函数调用的其他函数:")
                for result in context["callees"][:5]:
                    if result.name:
                        output_parts.append(f"- {result.name} ({result.file_path})")
            return ToolResult(
                success=True,
                data="\n".join(output_parts),
                metadata={
                    "function_name": function_name,
                    "definition_count": len(context["definition"]),
                    "callers_count": len(context["callers"]),
                    "callees_count": len(context["callees"]),
                }
            )
        except Exception as e:
            return ToolResult(
                success=False,
                error=f"函数上下文搜索失败: {str(e)}",
            )

View File

@ -0,0 +1,647 @@
"""
沙箱执行工具
Docker 沙箱中执行代码和命令进行漏洞验证
"""
import asyncio
import json
import logging
import tempfile
import os
import shutil
from typing import Optional, List, Dict, Any
from pydantic import BaseModel, Field
from dataclasses import dataclass
from .base import AgentTool, ToolResult
logger = logging.getLogger(__name__)
@dataclass
class SandboxConfig:
    """Resource and isolation limits applied to sandbox containers."""
    # Container image used for execution.
    image: str = "python:3.11-slim"
    # Hard memory cap (Docker mem_limit syntax).
    memory_limit: str = "512m"
    # CPU budget in whole cores (1.0 == one core via cpu_quota/cpu_period).
    cpu_limit: float = 1.0
    # Default wall-clock timeout, in seconds.
    timeout: int = 60
    # Docker network mode; "none" keeps the container fully isolated.
    network_mode: str = "none"  # none, bridge, host
    # Mount the container root filesystem read-only.
    read_only: bool = True
    # Non-root uid:gid the command runs as.
    user: str = "1000:1000"
class SandboxManager:
    """Docker-backed sandbox manager.

    Creates short-lived, locked-down containers (all capabilities dropped,
    no privilege escalation, CPU/memory limits, network isolated by
    default) to run commands, Python snippets and HTTP probes used for
    vulnerability verification. When Docker is unreachable the manager
    degrades gracefully: ``is_available`` is False and every call returns
    an error result instead of raising.
    """

    def __init__(self, config: Optional["SandboxConfig"] = None):
        # Container resource/security limits; defaults are restrictive.
        self.config = config or SandboxConfig()
        self._docker_client = None  # set by initialize() once the daemon answers
        self._initialized = False

    async def initialize(self):
        """Lazily create and ping the Docker client; idempotent."""
        if self._initialized:
            return
        try:
            import docker
            self._docker_client = docker.from_env()
            # Ping to make sure the daemon is actually reachable.
            self._docker_client.ping()
            self._initialized = True
            logger.info("Docker sandbox manager initialized")
        except Exception as e:
            logger.warning(f"Docker not available: {e}")
            self._docker_client = None

    @property
    def is_available(self) -> bool:
        """Whether a working Docker client is present."""
        return self._docker_client is not None

    async def execute_command(
        self,
        command: str,
        working_dir: Optional[str] = None,
        env: Optional[Dict[str, str]] = None,
        timeout: Optional[int] = None,
        network_mode: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Run a shell command inside a one-shot sandbox container.

        Args:
            command: Shell command text (executed via ``sh -c``).
            working_dir: Working directory inside the container.
            env: Extra environment variables.
            timeout: Seconds to wait before killing the container.
            network_mode: Per-call override of the configured Docker network
                mode (e.g. "bridge" to allow outbound HTTP). Defaults to
                ``self.config.network_mode``. Added so callers no longer
                need to mutate the shared config (the previous
                implementation raced under concurrent use).

        Returns:
            Dict with keys success / stdout / stderr / exit_code / error.
        """
        if not self.is_available:
            return {
                "success": False,
                "error": "Docker 不可用",
                "stdout": "",
                "stderr": "",
                "exit_code": -1,
            }
        timeout = timeout or self.config.timeout
        network_mode = network_mode or self.config.network_mode
        try:
            # Scratch directory mounted read-write at /workspace.
            with tempfile.TemporaryDirectory() as temp_dir:
                container_config = {
                    "image": self.config.image,
                    "command": ["sh", "-c", command],
                    "detach": True,
                    "mem_limit": self.config.memory_limit,
                    "cpu_period": 100000,
                    "cpu_quota": int(100000 * self.config.cpu_limit),
                    "network_mode": network_mode,
                    "user": self.config.user,
                    "read_only": self.config.read_only,
                    "volumes": {
                        temp_dir: {"bind": "/workspace", "mode": "rw"},
                    },
                    "working_dir": working_dir or "/workspace",
                    "environment": env or {},
                    # Hardening: drop all capabilities, forbid privilege gain.
                    "cap_drop": ["ALL"],
                    "security_opt": ["no-new-privileges:true"],
                }
                # Docker SDK calls are blocking; run them off the event loop.
                container = await asyncio.to_thread(
                    self._docker_client.containers.run,
                    **container_config
                )
                try:
                    # Wait for the container to exit, bounded by the timeout.
                    result = await asyncio.wait_for(
                        asyncio.to_thread(container.wait),
                        timeout=timeout
                    )
                    stdout = await asyncio.to_thread(
                        container.logs, stdout=True, stderr=False
                    )
                    stderr = await asyncio.to_thread(
                        container.logs, stdout=False, stderr=True
                    )
                    return {
                        "success": result["StatusCode"] == 0,
                        "stdout": stdout.decode('utf-8', errors='ignore')[:10000],
                        "stderr": stderr.decode('utf-8', errors='ignore')[:2000],
                        "exit_code": result["StatusCode"],
                        "error": None,
                    }
                except asyncio.TimeoutError:
                    await asyncio.to_thread(container.kill)
                    return {
                        "success": False,
                        "error": f"执行超时 ({timeout}秒)",
                        "stdout": "",
                        "stderr": "",
                        "exit_code": -1,
                    }
                finally:
                    # Always remove the container, even on timeout or error.
                    await asyncio.to_thread(container.remove, force=True)
        except Exception as e:
            logger.error(f"Sandbox execution error: {e}")
            return {
                "success": False,
                "error": str(e),
                "stdout": "",
                "stderr": "",
                "exit_code": -1,
            }

    async def execute_python(
        self,
        code: str,
        timeout: Optional[int] = None,
    ) -> Dict[str, Any]:
        """
        Run a Python snippet inside the sandbox with ``python3 -c``.

        Args:
            code: Python source text.
            timeout: Seconds before the container is killed.

        Returns:
            Same result dict as execute_command().
        """
        # POSIX single-quote escaping: close the quote, emit \', reopen.
        escaped_code = code.replace("'", "'\\''")
        command = f"python3 -c '{escaped_code}'"
        return await self.execute_command(command, timeout=timeout)

    async def execute_http_request(
        self,
        method: str,
        url: str,
        headers: Optional[Dict[str, str]] = None,
        data: Optional[str] = None,
        timeout: int = 30,
    ) -> Dict[str, Any]:
        """
        Issue an HTTP request from inside the sandbox via curl.

        The request runs with network_mode="bridge" for this call only, so
        the manager's default isolation is never mutated (the previous
        implementation temporarily rewrote ``self.config.network_mode`` and
        restored it afterwards, which raced when calls overlapped).

        NOTE(review): the default python:3.11-slim image does not ship
        curl — confirm the configured image actually provides it.

        Args:
            method: HTTP method (GET/POST/...).
            url: Target URL.
            headers: Optional request headers.
            data: Optional request body.
            timeout: Seconds before the request is aborted.

        Returns:
            Dict with success / status_code / body / error.
        """
        # Build the curl command; -w appends the status code on its own line.
        curl_parts = ["curl", "-s", "-S", "-w", "'\\n%{http_code}'", "-X", method]
        if headers:
            for key, value in headers.items():
                curl_parts.extend(["-H", f"'{key}: {value}'"])
        if data:
            curl_parts.extend(["-d", f"'{data}'"])
        curl_parts.append(f"'{url}'")
        command = " ".join(curl_parts)
        # Allow network access for this single command only.
        result = await self.execute_command(
            command, timeout=timeout, network_mode="bridge"
        )
        if result["success"] and result["stdout"]:
            lines = result["stdout"].strip().split('\n')
            if lines:
                # Last line is the status code written by curl's -w format.
                status_code = lines[-1].strip()
                body = '\n'.join(lines[:-1])
                return {
                    "success": True,
                    "status_code": int(status_code) if status_code.isdigit() else 0,
                    "body": body[:5000],
                    "error": None,
                }
        return {
            "success": False,
            "status_code": 0,
            "body": "",
            "error": result.get("error") or result.get("stderr"),
        }

    async def verify_vulnerability(
        self,
        vulnerability_type: str,
        target_url: str,
        payload: str,
        expected_pattern: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Probe ``target_url`` with ``payload`` and judge exploitability.

        Bug fix: ``re`` was previously imported only inside the
        ``expected_pattern`` branch, so the generic per-type checks below
        raised NameError whenever no pattern was supplied; it is now
        imported unconditionally.

        Args:
            vulnerability_type: Kind of flaw being tested (sql_injection, ...).
            target_url: URL to probe.
            payload: Attack payload (sent as the POST body unless the URL
                already carries a query string).
            expected_pattern: Optional regex expected in the response body.

        Returns:
            Dict describing the probe: is_vulnerable, evidence, error, and
            (on a completed request) response_status / response_length.
        """
        import re  # needed by both the expected_pattern and generic checks

        verification_result = {
            "vulnerability_type": vulnerability_type,
            "target_url": target_url,
            "payload": payload,
            "is_vulnerable": False,
            "evidence": None,
            "error": None,
        }
        try:
            # URLs that already carry a query string are probed via GET;
            # otherwise the payload is sent as a POST body.
            response = await self.execute_http_request(
                method="GET" if "?" in target_url else "POST",
                url=target_url,
                data=payload if "?" not in target_url else None,
            )
            if not response["success"]:
                verification_result["error"] = response.get("error")
                return verification_result
            body = response.get("body", "")
            status_code = response.get("status_code", 0)
            if expected_pattern:
                # Caller supplied an explicit success signature.
                if re.search(expected_pattern, body, re.IGNORECASE):
                    verification_result["is_vulnerable"] = True
                    verification_result["evidence"] = f"响应中包含预期模式: {expected_pattern}"
            else:
                # Generic heuristics per vulnerability type.
                if vulnerability_type == "sql_injection":
                    error_patterns = [
                        r"SQL syntax",
                        r"mysql_fetch",
                        r"ORA-\d+",
                        r"PostgreSQL.*ERROR",
                        r"SQLite.*error",
                        r"ODBC.*Driver",
                    ]
                    for pattern in error_patterns:
                        if re.search(pattern, body, re.IGNORECASE):
                            verification_result["is_vulnerable"] = True
                            verification_result["evidence"] = f"SQL错误信息: {pattern}"
                            break
                elif vulnerability_type == "xss":
                    if payload in body:
                        verification_result["is_vulnerable"] = True
                        verification_result["evidence"] = "XSS payload 被反射到响应中"
                elif vulnerability_type == "command_injection":
                    # Typical `id` output or /etc/passwd disclosure markers.
                    if "uid=" in body or "root:" in body:
                        verification_result["is_vulnerable"] = True
                        verification_result["evidence"] = "命令执行成功"
            verification_result["response_status"] = status_code
            verification_result["response_length"] = len(body)
        except Exception as e:
            verification_result["error"] = str(e)
        return verification_result
class SandboxCommandInput(BaseModel):
    """Input schema for the sandbox command-execution tool."""
    # Shell command to run inside the sandbox container.
    command: str = Field(description="要执行的命令")
    # Wall-clock timeout in seconds.
    timeout: int = Field(default=30, description="超时时间(秒)")
class SandboxTool(AgentTool):
    """
    Sandbox command-execution tool.

    Runs whitelisted commands inside an isolated Docker container via
    SandboxManager, for PoC execution and payload testing.
    """
    # Allowed command prefixes.
    # NOTE(review): the validation below uses startswith(), so any first
    # word that merely begins with an allowed name (e.g. "catx") passes,
    # and because commands run under `sh -c`, shell operators (';', '&&',
    # '|') can chain further commands past this check. Containment relies
    # on the Docker sandbox itself — consider exact-match validation and
    # metacharacter rejection.
    ALLOWED_COMMANDS = [
        "python", "python3", "node", "curl", "wget",
        "cat", "head", "tail", "grep", "find", "ls",
        "echo", "printf", "test", "id", "whoami",
    ]
    def __init__(self, sandbox_manager: Optional[SandboxManager] = None):
        super().__init__()
        self.sandbox_manager = sandbox_manager or SandboxManager()
    @property
    def name(self) -> str:
        # Stable tool identifier used by the agent framework.
        return "sandbox_exec"
    @property
    def description(self) -> str:
        return """在安全沙箱中执行命令或代码。
用于验证漏洞测试 PoC 或执行安全检查
安全限制:
- 命令在 Docker 容器中执行
- 网络默认隔离
- 资源有限制
- 只允许特定命令
允许的命令: python, python3, node, curl, cat, grep, find, ls, echo, id
使用场景:
- 验证命令注入漏洞
- 执行 PoC 代码
- 测试 payload 效果"""
    @property
    def args_schema(self):
        # Pydantic model describing the accepted arguments.
        return SandboxCommandInput
    async def _execute(
        self,
        command: str,
        timeout: int = 30,
        **kwargs
    ) -> ToolResult:
        """Validate the command against the whitelist, then run it sandboxed."""
        # Lazily bring up the Docker client (no-op if already initialized).
        await self.sandbox_manager.initialize()
        if not self.sandbox_manager.is_available:
            return ToolResult(
                success=False,
                error="沙箱环境不可用Docker 未安装或未运行)",
            )
        # Security check: the first word must match an allowed prefix.
        cmd_parts = command.strip().split()
        if not cmd_parts:
            return ToolResult(success=False, error="命令不能为空")
        base_cmd = cmd_parts[0]
        if not any(base_cmd.startswith(allowed) for allowed in self.ALLOWED_COMMANDS):
            return ToolResult(
                success=False,
                error=f"命令 '{base_cmd}' 不在允许列表中。允许的命令: {', '.join(self.ALLOWED_COMMANDS)}",
            )
        # Run the command in the sandbox container.
        result = await self.sandbox_manager.execute_command(
            command=command,
            timeout=timeout,
        )
        # Render stdout/stderr/exit code as a readable report.
        output_parts = ["🐳 沙箱执行结果\n"]
        output_parts.append(f"命令: {command}")
        output_parts.append(f"退出码: {result['exit_code']}")
        if result["stdout"]:
            output_parts.append(f"\n标准输出:\n```\n{result['stdout']}\n```")
        if result["stderr"]:
            output_parts.append(f"\n标准错误:\n```\n{result['stderr']}\n```")
        if result.get("error"):
            output_parts.append(f"\n错误: {result['error']}")
        return ToolResult(
            success=result["success"],
            data="\n".join(output_parts),
            error=result.get("error"),
            metadata={
                "command": command,
                "exit_code": result["exit_code"],
            }
        )
class HttpRequestInput(BaseModel):
    """Input schema for the sandboxed HTTP request tool."""
    # HTTP verb to use.
    method: str = Field(default="GET", description="HTTP 方法 (GET, POST, PUT, DELETE)")
    # Target URL of the request.
    url: str = Field(description="请求 URL")
    # Optional request headers.
    headers: Optional[Dict[str, str]] = Field(default=None, description="请求头")
    # Optional request body.
    data: Optional[str] = Field(default=None, description="请求体")
    # Wall-clock timeout in seconds.
    timeout: int = Field(default=30, description="超时时间(秒)")
class SandboxHttpTool(AgentTool):
    """
    Sandboxed HTTP request tool.

    Sends HTTP requests from inside the sandbox container (via curl),
    used for testing web vulnerabilities such as SQLi/XSS/SSRF.
    """
    def __init__(self, sandbox_manager: Optional[SandboxManager] = None):
        super().__init__()
        self.sandbox_manager = sandbox_manager or SandboxManager()
    @property
    def name(self) -> str:
        # Stable tool identifier used by the agent framework.
        return "sandbox_http"
    @property
    def description(self) -> str:
        return """在沙箱中发送 HTTP 请求。
用于测试 Web 漏洞如 SQL 注入XSSSSRF
输入:
- method: HTTP 方法
- url: 请求 URL
- headers: 可选请求头
- data: 可选请求体
- timeout: 超时时间
使用场景:
- 验证 SQL 注入漏洞
- 测试 XSS payload
- 验证 SSRF 漏洞
- 测试认证绕过"""
    @property
    def args_schema(self):
        # Pydantic model describing the accepted arguments.
        return HttpRequestInput
    async def _execute(
        self,
        url: str,
        method: str = "GET",
        headers: Optional[Dict[str, str]] = None,
        data: Optional[str] = None,
        timeout: int = 30,
        **kwargs
    ) -> ToolResult:
        """Send the request via the sandbox and render the response."""
        await self.sandbox_manager.initialize()
        if not self.sandbox_manager.is_available:
            return ToolResult(
                success=False,
                error="沙箱环境不可用",
            )
        result = await self.sandbox_manager.execute_http_request(
            method=method,
            url=url,
            headers=headers,
            data=data,
            timeout=timeout,
        )
        output_parts = ["🌐 HTTP 请求结果\n"]
        output_parts.append(f"请求: {method} {url}")
        if headers:
            output_parts.append(f"请求头: {json.dumps(headers, ensure_ascii=False)}")
        if data:
            output_parts.append(f"请求体: {data[:500]}")
        output_parts.append(f"\n状态码: {result.get('status_code', 'N/A')}")
        if result.get("body"):
            # Truncate long bodies so the report stays readable.
            body = result["body"]
            if len(body) > 2000:
                body = body[:2000] + f"\n... (截断,共 {len(result['body'])} 字符)"
            output_parts.append(f"\n响应内容:\n```\n{body}\n```")
        if result.get("error"):
            output_parts.append(f"\n错误: {result['error']}")
        return ToolResult(
            success=result["success"],
            data="\n".join(output_parts),
            error=result.get("error"),
            metadata={
                "method": method,
                "url": url,
                "status_code": result.get("status_code"),
                "response_length": len(result.get("body", "")),
            }
        )
class VulnerabilityVerifyInput(BaseModel):
    """Input schema for the vulnerability verification tool."""
    # Vulnerability category being tested.
    vulnerability_type: str = Field(description="漏洞类型 (sql_injection, xss, command_injection, etc.)")
    # URL of the target endpoint.
    target_url: str = Field(description="目标 URL")
    # Attack payload to send.
    payload: str = Field(description="攻击载荷")
    # Optional regex whose presence in the response confirms the flaw.
    expected_pattern: Optional[str] = Field(default=None, description="期望在响应中匹配的正则模式")
class VulnerabilityVerifyTool(AgentTool):
    """
    Vulnerability verification tool.

    Sends a payload-bearing request from the sandbox and reports whether
    the response indicates the flaw is actually exploitable.
    """
    def __init__(self, sandbox_manager: Optional[SandboxManager] = None):
        super().__init__()
        self.sandbox_manager = sandbox_manager or SandboxManager()
    @property
    def name(self) -> str:
        # Stable tool identifier used by the agent framework.
        return "verify_vulnerability"
    @property
    def description(self) -> str:
        return """验证漏洞是否真实存在。
发送包含攻击载荷的请求分析响应判断漏洞是否可利用
输入:
- vulnerability_type: 漏洞类型
- target_url: 目标 URL
- payload: 攻击载荷
- expected_pattern: 可选期望在响应中匹配的模式
支持的漏洞类型:
- sql_injection: SQL 注入
- xss: 跨站脚本
- command_injection: 命令注入
- path_traversal: 路径遍历
- ssrf: 服务端请求伪造"""
    @property
    def args_schema(self):
        # Pydantic model describing the accepted arguments.
        return VulnerabilityVerifyInput
    async def _execute(
        self,
        vulnerability_type: str,
        target_url: str,
        payload: str,
        expected_pattern: Optional[str] = None,
        **kwargs
    ) -> ToolResult:
        """Delegate to SandboxManager.verify_vulnerability and format the verdict.

        Note: returns success=True even when the probe errored or found
        nothing — the verdict itself lives in metadata["is_vulnerable"].
        """
        await self.sandbox_manager.initialize()
        if not self.sandbox_manager.is_available:
            return ToolResult(
                success=False,
                error="沙箱环境不可用",
            )
        result = await self.sandbox_manager.verify_vulnerability(
            vulnerability_type=vulnerability_type,
            target_url=target_url,
            payload=payload,
            expected_pattern=expected_pattern,
        )
        output_parts = ["🔍 漏洞验证结果\n"]
        output_parts.append(f"漏洞类型: {vulnerability_type}")
        output_parts.append(f"目标: {target_url}")
        output_parts.append(f"Payload: {payload[:200]}")
        if result["is_vulnerable"]:
            output_parts.append(f"\n🔴 结果: 漏洞已确认!")
            output_parts.append(f"证据: {result.get('evidence', 'N/A')}")
        else:
            output_parts.append(f"\n🟢 结果: 未能确认漏洞")
            if result.get("error"):
                output_parts.append(f"错误: {result['error']}")
        if result.get("response_status"):
            output_parts.append(f"\nHTTP 状态码: {result['response_status']}")
        return ToolResult(
            success=True,
            data="\n".join(output_parts),
            metadata={
                "vulnerability_type": vulnerability_type,
                "is_vulnerable": result["is_vulnerable"],
                "evidence": result.get("evidence"),
            }
        )

View File

@ -144,3 +144,4 @@ class BaiduAdapter(BaseLLMAdapter):
return True return True

View File

@ -82,3 +82,4 @@ class DoubaoAdapter(BaseLLMAdapter):
return True return True

View File

@ -85,3 +85,4 @@ class MinimaxAdapter(BaseLLMAdapter):
return True return True

View File

@ -133,3 +133,4 @@ class BaseLLMAdapter(ABC):
self._client = None self._client = None

View File

@ -119,3 +119,4 @@ DEFAULT_BASE_URLS: Dict[LLMProvider, str] = {
} }

View File

@ -0,0 +1,18 @@
"""
RAG (Retrieval-Augmented Generation) 系统
用于代码索引和语义检索
"""
from .splitter import CodeSplitter, CodeChunk
from .embeddings import EmbeddingService
from .indexer import CodeIndexer
from .retriever import CodeRetriever
__all__ = [
"CodeSplitter",
"CodeChunk",
"EmbeddingService",
"CodeIndexer",
"CodeRetriever",
]

View File

@ -0,0 +1,658 @@
"""
嵌入模型服务
支持多种嵌入模型提供商: OpenAI, Azure, Ollama, Cohere, HuggingFace, Jina
"""
import asyncio
import hashlib
import logging
from typing import List, Dict, Any, Optional
from abc import ABC, abstractmethod
from dataclasses import dataclass
import httpx
from app.core.config import settings
logger = logging.getLogger(__name__)
@dataclass
class EmbeddingResult:
    """A single embedding vector plus accounting metadata."""
    # The embedding vector itself.
    embedding: List[float]
    # Approximate tokens consumed for this text (per-item share of the batch).
    tokens_used: int
    # Name of the model that produced the vector.
    model: str
class EmbeddingProvider(ABC):
    """Abstract base class for embedding backends."""
    @abstractmethod
    async def embed_text(self, text: str) -> EmbeddingResult:
        """Embed a single text and return its vector."""
        pass
    @abstractmethod
    async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
        """Embed a batch of texts, one result per input."""
        pass
    @property
    @abstractmethod
    def dimension(self) -> int:
        """Dimensionality of the vectors this provider produces."""
        pass
class OpenAIEmbedding(EmbeddingProvider):
    """OpenAI embeddings backend (POST {base_url}/embeddings)."""

    # Known models and their vector dimensions.
    MODELS = {
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
        "text-embedding-ada-002": 1536,
    }

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model: str = "text-embedding-3-small",
    ):
        self.api_key = api_key or settings.LLM_API_KEY
        self.base_url = base_url or "https://api.openai.com/v1"
        self.model = model
        # Unknown models fall back to the common 1536 dimension.
        self._dimension = self.MODELS.get(model, 1536)

    @property
    def dimension(self) -> int:
        """Vector size produced by the configured model."""
        return self._dimension

    async def embed_text(self, text: str) -> EmbeddingResult:
        """Embed a single string via the batch path."""
        batch = await self.embed_texts([text])
        return batch[0]

    async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
        """Embed a batch of strings with one API call."""
        if not texts:
            return []
        # Clip over-long inputs; 8191 characters can never exceed 8191 tokens.
        max_length = 8191
        clipped = [text[:max_length] for text in texts]
        request_headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        request_body = {
            "model": self.model,
            "input": clipped,
        }
        endpoint = f"{self.base_url.rstrip('/')}/embeddings"
        async with httpx.AsyncClient(timeout=60) as client:
            response = await client.post(endpoint, headers=request_headers, json=request_body)
            response.raise_for_status()
            data = response.json()
        # Usage is reported for the whole batch; split it evenly per item.
        per_item_tokens = data.get("usage", {}).get("total_tokens", 0) // len(texts)
        return [
            EmbeddingResult(
                embedding=item["embedding"],
                tokens_used=per_item_tokens,
                model=self.model,
            )
            for item in data.get("data", [])
        ]
class AzureOpenAIEmbedding(EmbeddingProvider):
    """
    Azure OpenAI embeddings backend.

    Uses the GA API version 2024-10-21. Endpoint shape:
    https://<resource>.openai.azure.com/openai/deployments/<deployment>/embeddings?api-version=2024-10-21

    NOTE(review): ``self.model`` is used as the *deployment* name in the
    URL — confirm deployments are named after the model.
    """
    MODELS = {
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
        "text-embedding-ada-002": 1536,
    }
    # Latest GA API version.
    API_VERSION = "2024-10-21"
    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model: str = "text-embedding-3-small",
    ):
        self.api_key = api_key
        self.base_url = base_url or "https://your-resource.openai.azure.com"
        self.model = model
        # Unknown models fall back to the common 1536 dimension.
        self._dimension = self.MODELS.get(model, 1536)
    @property
    def dimension(self) -> int:
        # Vector size produced by the configured model.
        return self._dimension
    async def embed_text(self, text: str) -> EmbeddingResult:
        # Single-text convenience wrapper over the batch path.
        results = await self.embed_texts([text])
        return results[0]
    async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
        """Embed a batch of strings with one API call."""
        if not texts:
            return []
        # Clip over-long inputs; 8191 characters can never exceed 8191 tokens.
        max_length = 8191
        truncated_texts = [text[:max_length] for text in texts]
        headers = {
            "api-key": self.api_key,
            "Content-Type": "application/json",
        }
        payload = {
            "input": truncated_texts,
        }
        # Azure URL format with the latest GA API version.
        url = f"{self.base_url.rstrip('/')}/openai/deployments/{self.model}/embeddings?api-version={self.API_VERSION}"
        async with httpx.AsyncClient(timeout=60) as client:
            response = await client.post(url, headers=headers, json=payload)
            response.raise_for_status()
            data = response.json()
            results = []
            for item in data.get("data", []):
                results.append(EmbeddingResult(
                    embedding=item["embedding"],
                    # Batch token usage split evenly across inputs.
                    tokens_used=data.get("usage", {}).get("total_tokens", 0) // len(texts),
                    model=self.model,
                ))
            return results
class OllamaEmbedding(EmbeddingProvider):
    """
    Local Ollama embeddings backend.

    Uses the newer /api/embed endpoint (available since 2024):
    - supports batch embedding
    - takes an 'input' parameter that may be a string or list of strings
    """
    MODELS = {
        "nomic-embed-text": 768,
        "mxbai-embed-large": 1024,
        "all-minilm": 384,
        "snowflake-arctic-embed": 1024,
        "bge-m3": 1024,
        "qwen3-embedding": 1024,
    }
    def __init__(
        self,
        base_url: Optional[str] = None,
        model: str = "nomic-embed-text",
    ):
        self.base_url = base_url or "http://localhost:11434"
        self.model = model
        # Unknown models fall back to 768 dimensions.
        self._dimension = self.MODELS.get(model, 768)
    @property
    def dimension(self) -> int:
        # Vector size produced by the configured model.
        return self._dimension
    async def embed_text(self, text: str) -> EmbeddingResult:
        # Single-text convenience wrapper over the batch path.
        results = await self.embed_texts([text])
        return results[0]
    async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
        """Embed a batch of strings via Ollama's /api/embed endpoint."""
        if not texts:
            return []
        # Newer Ollama /api/embed endpoint.
        url = f"{self.base_url.rstrip('/')}/api/embed"
        payload = {
            "model": self.model,
            "input": texts,  # new API uses 'input' and supports batching
        }
        async with httpx.AsyncClient(timeout=120) as client:
            response = await client.post(url, json=payload)
            response.raise_for_status()
            data = response.json()
            # New API response shape: {"embeddings": [[...], [...], ...]}
            embeddings = data.get("embeddings", [])
            results = []
            for i, embedding in enumerate(embeddings):
                results.append(EmbeddingResult(
                    embedding=embedding,
                    # Rough token estimate: ~4 characters per token.
                    tokens_used=len(texts[i]) // 4,
                    model=self.model,
                ))
            return results
class CohereEmbedding(EmbeddingProvider):
    """
    Cohere embeddings backend.

    Uses the v2 API (since 2024):
    - endpoint: https://api.cohere.com/v2/embed
    - 'inputs' parameter replaces 'texts'
    - 'embedding_types' must be specified
    """
    MODELS = {
        "embed-english-v3.0": 1024,
        "embed-multilingual-v3.0": 1024,
        "embed-english-light-v3.0": 384,
        "embed-multilingual-light-v3.0": 384,
        "embed-v4.0": 1024,  # newest model
    }
    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model: str = "embed-multilingual-v3.0",
    ):
        self.api_key = api_key
        # v2 API base endpoint.
        self.base_url = base_url or "https://api.cohere.com/v2"
        self.model = model
        # Unknown models fall back to 1024 dimensions.
        self._dimension = self.MODELS.get(model, 1024)
    @property
    def dimension(self) -> int:
        # Vector size produced by the configured model.
        return self._dimension
    async def embed_text(self, text: str) -> EmbeddingResult:
        # Single-text convenience wrapper over the batch path.
        results = await self.embed_texts([text])
        return results[0]
    async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
        """Embed a batch of strings via Cohere's v2 /embed endpoint."""
        if not texts:
            return []
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        # v2 API parameter shape.
        payload = {
            "model": self.model,
            "inputs": texts,  # v2 uses 'inputs' rather than 'texts'
            "input_type": "search_document",
            "embedding_types": ["float"],  # v2 requires explicit embedding types
        }
        url = f"{self.base_url.rstrip('/')}/embed"
        async with httpx.AsyncClient(timeout=60) as client:
            response = await client.post(url, headers=headers, json=payload)
            response.raise_for_status()
            data = response.json()
            results = []
            # v2 response shape: {"embeddings": {"float": [[...], [...]]}, ...};
            # tolerate a plain list for backward compatibility.
            embeddings_data = data.get("embeddings", {})
            embeddings = embeddings_data.get("float", []) if isinstance(embeddings_data, dict) else embeddings_data
            for embedding in embeddings:
                results.append(EmbeddingResult(
                    embedding=embedding,
                    # Billed input tokens split evenly across the batch.
                    tokens_used=data.get("meta", {}).get("billed_units", {}).get("input_tokens", 0) // max(len(texts), 1),
                    model=self.model,
                ))
            return results
class HuggingFaceEmbedding(EmbeddingProvider):
    """
    HuggingFace Inference Providers embeddings backend.

    Uses the Router endpoint (since 2025):
    https://router.huggingface.co/hf-inference/models/{model}/pipeline/feature-extraction
    """
    MODELS = {
        "sentence-transformers/all-MiniLM-L6-v2": 384,
        "sentence-transformers/all-mpnet-base-v2": 768,
        "BAAI/bge-large-zh-v1.5": 1024,
        "BAAI/bge-m3": 1024,
        "BAAI/bge-small-en-v1.5": 384,
        "BAAI/bge-base-en-v1.5": 768,
    }
    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model: str = "BAAI/bge-m3",
    ):
        self.api_key = api_key
        # Router endpoint base URL.
        self.base_url = base_url or "https://router.huggingface.co"
        self.model = model
        # Unknown models fall back to 1024 dimensions.
        self._dimension = self.MODELS.get(model, 1024)
    @property
    def dimension(self) -> int:
        # Vector size produced by the configured model.
        return self._dimension
    async def embed_text(self, text: str) -> EmbeddingResult:
        # Single-text convenience wrapper over the batch path.
        results = await self.embed_texts([text])
        return results[0]
    async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
        """Embed a batch of strings via the feature-extraction pipeline."""
        if not texts:
            return []
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        # HuggingFace Router URL shape:
        # https://router.huggingface.co/hf-inference/models/{model}/pipeline/feature-extraction
        url = f"{self.base_url.rstrip('/')}/hf-inference/models/{self.model}/pipeline/feature-extraction"
        payload = {
            "inputs": texts,
            "options": {
                "wait_for_model": True,
            }
        }
        async with httpx.AsyncClient(timeout=120) as client:
            response = await client.post(url, headers=headers, json=payload)
            response.raise_for_status()
            data = response.json()
            results = []
            # Response shape: [[embedding1], [embedding2], ...]
            for embedding in data:
                # Some models return one nested list per input; this code
                # keeps only the first row. NOTE(review): verify whether a
                # pooled (mean) vector would be more appropriate here.
                if isinstance(embedding, list) and len(embedding) > 0:
                    if isinstance(embedding[0], list):
                        embedding = embedding[0]
                results.append(EmbeddingResult(
                    embedding=embedding,
                    # len(results) is the index of the input matching this
                    # output; ~4 characters per token as a rough estimate.
                    tokens_used=len(texts[len(results)]) // 4,
                    model=self.model,
                ))
            return results
class JinaEmbedding(EmbeddingProvider):
    """Jina AI embedding provider (OpenAI-compatible /embeddings endpoint)."""
    # Supported models mapped to their output dimensions.
    MODELS = {
        "jina-embeddings-v2-base-code": 768,
        "jina-embeddings-v2-base-en": 768,
        "jina-embeddings-v2-base-zh": 768,
        "jina-embeddings-v2-small-en": 512,
    }
    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model: str = "jina-embeddings-v2-base-code",
    ):
        """Configure credentials, endpoint and model; unknown models default to 768 dims."""
        self.api_key = api_key
        self.base_url = base_url or "https://api.jina.ai/v1"
        self.model = model
        self._dimension = self.MODELS.get(model, 768)
    @property
    def dimension(self) -> int:
        """Dimension of the embedding vectors produced by the configured model."""
        return self._dimension
    async def embed_text(self, text: str) -> EmbeddingResult:
        """Embed one text; thin wrapper around the batch call."""
        return (await self.embed_texts([text]))[0]
    async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
        """Embed a batch of texts through the Jina /embeddings API."""
        if not texts:
            return []
        request_headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        body = {
            "model": self.model,
            "input": texts,
        }
        endpoint = f"{self.base_url.rstrip('/')}/embeddings"
        async with httpx.AsyncClient(timeout=60) as client:
            response = await client.post(endpoint, headers=request_headers, json=body)
            response.raise_for_status()
            data = response.json()
        # Spread the reported token usage evenly across the batch.
        per_text_tokens = data.get("usage", {}).get("total_tokens", 0) // len(texts)
        return [
            EmbeddingResult(
                embedding=item["embedding"],
                tokens_used=per_text_tokens,
                model=self.model,
            )
            for item in data.get("data", [])
        ]
class EmbeddingService:
    """
    Embedding service.

    Wraps a concrete embedding provider behind one interface and adds an
    in-process cache keyed by text hash.

    Supported providers:
        - openai: OpenAI (default)
        - azure: Azure OpenAI
        - ollama: local Ollama
        - cohere: Cohere
        - huggingface: HuggingFace Inference API
        - jina: Jina AI
    """
    def __init__(
        self,
        provider: Optional[str] = None,
        model: Optional[str] = None,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        cache_enabled: bool = True,
    ):
        """
        Initialize the embedding service.

        Args:
            provider: Provider name (openai, azure, ollama, cohere, huggingface, jina).
            model: Model name.
            api_key: API key.
            base_url: API base URL.
            cache_enabled: Whether to cache embeddings by text hash.
        """
        self.cache_enabled = cache_enabled
        # NOTE(review): cache is unbounded; a long-running process may want
        # an LRU bound here — confirm expected lifetime with callers.
        self._cache: Dict[str, List[float]] = {}
        # Resolve provider/model from settings when not given explicitly.
        provider = provider or getattr(settings, 'EMBEDDING_PROVIDER', 'openai')
        model = model or getattr(settings, 'EMBEDDING_MODEL', 'text-embedding-3-small')
        # Create the concrete provider instance.
        self._provider = self._create_provider(
            provider=provider,
            model=model,
            api_key=api_key,
            base_url=base_url,
        )
        logger.info(f"Embedding service initialized with {provider}/{model}")
    def _create_provider(
        self,
        provider: str,
        model: str,
        api_key: Optional[str],
        base_url: Optional[str],
    ) -> EmbeddingProvider:
        """Create the embedding provider; unknown names fall back to OpenAI."""
        provider = provider.lower()
        if provider == "ollama":
            return OllamaEmbedding(base_url=base_url, model=model)
        elif provider == "azure":
            return AzureOpenAIEmbedding(api_key=api_key, base_url=base_url, model=model)
        elif provider == "cohere":
            return CohereEmbedding(api_key=api_key, base_url=base_url, model=model)
        elif provider == "huggingface":
            return HuggingFaceEmbedding(api_key=api_key, base_url=base_url, model=model)
        elif provider == "jina":
            return JinaEmbedding(api_key=api_key, base_url=base_url, model=model)
        else:
            # Default to OpenAI for unrecognized provider names.
            return OpenAIEmbedding(api_key=api_key, base_url=base_url, model=model)
    @property
    def dimension(self) -> int:
        """Embedding vector dimension of the active provider."""
        return self._provider.dimension
    def _cache_key(self, text: str) -> str:
        """Stable cache key: first 32 hex chars of the text's SHA-256."""
        return hashlib.sha256(text.encode()).hexdigest()[:32]
    async def embed(self, text: str) -> List[float]:
        """
        Embed a single text.

        Blank/whitespace-only text returns a zero vector without calling
        the provider.

        Args:
            text: Text content.

        Returns:
            The embedding vector.
        """
        if not text or not text.strip():
            return [0.0] * self.dimension
        # Serve from cache when possible.
        if self.cache_enabled:
            cache_key = self._cache_key(text)
            if cache_key in self._cache:
                return self._cache[cache_key]
        result = await self._provider.embed_text(text)
        # Populate the cache for next time.
        if self.cache_enabled:
            self._cache[cache_key] = result.embedding
        return result.embedding
    async def embed_batch(
        self,
        texts: List[str],
        batch_size: int = 100,
        show_progress: bool = False,
    ) -> List[List[float]]:
        """
        Embed a batch of texts.

        Blank texts and cache hits are filled in locally; only misses are
        sent to the provider, in batches of `batch_size`. A failed batch
        falls back to zero vectors instead of raising.

        Args:
            texts: Texts to embed.
            batch_size: Provider batch size.
            show_progress: Unused; kept for interface compatibility.

        Returns:
            One embedding per input text, in input order.
        """
        if not texts:
            return []
        embeddings = []
        uncached_indices = []
        uncached_texts = []
        # Resolve blanks and cache hits first; leave None placeholders for misses.
        for i, text in enumerate(texts):
            if not text or not text.strip():
                embeddings.append([0.0] * self.dimension)
                continue
            if self.cache_enabled:
                cache_key = self._cache_key(text)
                if cache_key in self._cache:
                    embeddings.append(self._cache[cache_key])
                    continue
            embeddings.append(None)  # placeholder, filled below
            uncached_indices.append(i)
            uncached_texts.append(text)
        # Embed the misses in provider-sized batches.
        if uncached_texts:
            for i in range(0, len(uncached_texts), batch_size):
                batch = uncached_texts[i:i + batch_size]
                batch_indices = uncached_indices[i:i + batch_size]
                try:
                    results = await self._provider.embed_texts(batch)
                    for idx, result in zip(batch_indices, results):
                        embeddings[idx] = result.embedding
                        # Populate the cache.
                        if self.cache_enabled:
                            cache_key = self._cache_key(texts[idx])
                            self._cache[cache_key] = result.embedding
                except Exception as e:
                    logger.error(f"Batch embedding error: {e}")
                    # Fall back to zero vectors for the failed batch.
                    for idx in batch_indices:
                        if embeddings[idx] is None:
                            embeddings[idx] = [0.0] * self.dimension
                # Small delay between batches to avoid rate limits.
                # Fixed: previously slept even after the final batch.
                if i + batch_size < len(uncached_texts):
                    await asyncio.sleep(0.1)
        # Defensive: make sure no placeholder survives.
        return [e if e is not None else [0.0] * self.dimension for e in embeddings]
    def clear_cache(self):
        """Drop all cached embeddings."""
        self._cache.clear()
    @property
    def cache_size(self) -> int:
        """Number of cached embeddings."""
        return len(self._cache)

View File

@ -0,0 +1,585 @@
"""
代码索引器
将代码分块并索引到向量数据库
"""
import os
import asyncio
import logging
from typing import List, Dict, Any, Optional, AsyncGenerator, Callable
from pathlib import Path
from dataclasses import dataclass
import json
from .splitter import CodeSplitter, CodeChunk
from .embeddings import EmbeddingService
logger = logging.getLogger(__name__)
# File extensions treated as indexable text/source files.
TEXT_EXTENSIONS = {
    ".py", ".js", ".ts", ".tsx", ".jsx", ".java", ".go", ".rs",
    ".cpp", ".c", ".h", ".cc", ".hh", ".cs", ".php", ".rb",
    ".kt", ".swift", ".sql", ".sh", ".json", ".yml", ".yaml",
    ".xml", ".html", ".css", ".vue", ".svelte", ".md",
}
# Directory names skipped entirely while walking the tree
# (dependency, build, VCS and editor directories).
EXCLUDE_DIRS = {
    "node_modules", "vendor", "dist", "build", ".git",
    "__pycache__", ".pytest_cache", "coverage", ".nyc_output",
    ".vscode", ".idea", ".vs", "target", "out", "bin", "obj",
    "__MACOSX", ".next", ".nuxt", "venv", "env", ".env",
}
# Individual file names skipped (lockfiles and OS metadata).
EXCLUDE_FILES = {
    ".DS_Store", "package-lock.json", "yarn.lock", "pnpm-lock.yaml",
    "Cargo.lock", "poetry.lock", "composer.lock", "Gemfile.lock",
}
@dataclass
class IndexingProgress:
    """Mutable progress snapshot emitted while indexing a code base."""
    total_files: int = 0        # files discovered for indexing
    processed_files: int = 0    # files already split into chunks
    total_chunks: int = 0       # chunks produced so far
    indexed_chunks: int = 0     # chunks written to the vector store
    current_file: str = ""      # file currently being processed
    # Fixed annotation: the default IS None, so the type is Optional.
    # The mutable default is materialized per-instance in __post_init__.
    errors: Optional[List[str]] = None
    def __post_init__(self):
        # Give each instance its own error list.
        if self.errors is None:
            self.errors = []
    @property
    def progress_percentage(self) -> float:
        """Percentage of files processed (0.0 when nothing was discovered)."""
        if self.total_files == 0:
            return 0.0
        return (self.processed_files / self.total_files) * 100
@dataclass
class IndexingResult:
    """Final summary produced after an indexing run completes."""
    success: bool           # True when the run finished without a fatal error
    total_files: int        # files discovered for indexing
    indexed_files: int      # files actually processed
    total_chunks: int       # chunks written to the vector store
    errors: List[str]       # per-file error messages accumulated during the run
    collection_name: str    # vector collection the chunks were written to
class VectorStore:
    """Abstract base class for vector stores.

    Concrete backends (Chroma, in-memory) implement the same async
    interface so indexers and retrievers can swap them freely.
    """
    async def initialize(self):
        """Prepare the store for use (default: no-op)."""
        pass
    async def add_documents(
        self,
        ids: List[str],
        embeddings: List[List[float]],
        documents: List[str],
        metadatas: List[Dict[str, Any]],
    ):
        """Add documents with precomputed embeddings.

        The four lists are parallel: same length, same order.
        """
        raise NotImplementedError
    async def query(
        self,
        query_embedding: List[float],
        n_results: int = 10,
        where: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Return up to n_results nearest documents to query_embedding.

        Args:
            query_embedding: Query vector.
            n_results: Maximum number of hits to return.
            where: Optional exact-match metadata filter.
        """
        raise NotImplementedError
    async def delete_collection(self):
        """Delete the entire collection."""
        raise NotImplementedError
    async def get_count(self) -> int:
        """Return the number of stored documents."""
        raise NotImplementedError
class ChromaVectorStore(VectorStore):
    """Chroma-backed vector store."""
    def __init__(
        self,
        collection_name: str,
        persist_directory: Optional[str] = None,
    ):
        # persist_directory=None means a purely in-memory Chroma client.
        # Note: chromadb itself is imported lazily in initialize(), so
        # constructing this object never raises ImportError.
        self.collection_name = collection_name
        self.persist_directory = persist_directory
        self._client = None
        self._collection = None
    async def initialize(self):
        """Create the Chroma client and get-or-create the collection.

        Raises:
            ImportError: If chromadb is not installed.
        """
        try:
            import chromadb
            from chromadb.config import Settings
            if self.persist_directory:
                self._client = chromadb.PersistentClient(
                    path=self.persist_directory,
                    settings=Settings(anonymized_telemetry=False),
                )
            else:
                self._client = chromadb.Client(
                    settings=Settings(anonymized_telemetry=False),
                )
            # Cosine space matches the (1 - distance) -> score conversion
            # used by the retriever side.
            self._collection = self._client.get_or_create_collection(
                name=self.collection_name,
                metadata={"hnsw:space": "cosine"},
            )
            logger.info(f"Chroma collection '{self.collection_name}' initialized")
        except ImportError:
            raise ImportError("chromadb is required. Install with: pip install chromadb")
    async def add_documents(
        self,
        ids: List[str],
        embeddings: List[List[float]],
        documents: List[str],
        metadatas: List[Dict[str, Any]],
    ):
        """Add documents to Chroma in batches of 500.

        Chroma metadata values must be scalars, so lists are JSON-encoded
        and other non-scalars stringified; None values are dropped.
        """
        if not ids:
            return
        # Chroma restricts metadata value types; sanitize them first.
        cleaned_metadatas = []
        for meta in metadatas:
            cleaned = {}
            for k, v in meta.items():
                if isinstance(v, (str, int, float, bool)):
                    cleaned[k] = v
                elif isinstance(v, list):
                    # Lists become JSON strings (decoded again on retrieval).
                    cleaned[k] = json.dumps(v)
                elif v is not None:
                    cleaned[k] = str(v)
            cleaned_metadatas.append(cleaned)
        # Add in batches (Chroma enforces a per-call batch limit).
        batch_size = 500
        for i in range(0, len(ids), batch_size):
            batch_ids = ids[i:i + batch_size]
            batch_embeddings = embeddings[i:i + batch_size]
            batch_documents = documents[i:i + batch_size]
            batch_metadatas = cleaned_metadatas[i:i + batch_size]
            # Chroma's API is synchronous; keep the event loop free.
            await asyncio.to_thread(
                self._collection.add,
                ids=batch_ids,
                embeddings=batch_embeddings,
                documents=batch_documents,
                metadatas=batch_metadatas,
            )
    async def query(
        self,
        query_embedding: List[float],
        n_results: int = 10,
        where: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Query Chroma and flatten the single-query result lists.

        NOTE(review): assumes initialize() was called first; _collection
        is None otherwise — confirm callers always initialize.
        """
        result = await asyncio.to_thread(
            self._collection.query,
            query_embeddings=[query_embedding],
            n_results=n_results,
            where=where,
            include=["documents", "metadatas", "distances"],
        )
        # Chroma returns lists-of-lists (one per query); unwrap index 0.
        return {
            "ids": result["ids"][0] if result["ids"] else [],
            "documents": result["documents"][0] if result["documents"] else [],
            "metadatas": result["metadatas"][0] if result["metadatas"] else [],
            "distances": result["distances"][0] if result["distances"] else [],
        }
    async def delete_collection(self):
        """Delete the collection (no-op when never initialized)."""
        if self._client and self._collection:
            await asyncio.to_thread(
                self._client.delete_collection,
                name=self.collection_name,
            )
    async def get_count(self) -> int:
        """Return the number of stored documents (0 when uninitialized)."""
        if self._collection:
            return await asyncio.to_thread(self._collection.count)
        return 0
class InMemoryVectorStore(VectorStore):
    """In-memory vector store (for tests or small projects)."""
    def __init__(self, collection_name: str):
        self.collection_name = collection_name
        # id -> {"embedding": [...], "document": str, "metadata": {...}}
        self._documents: Dict[str, Dict[str, Any]] = {}
    async def initialize(self):
        """Nothing to set up; just report readiness."""
        logger.info(f"InMemory vector store '{self.collection_name}' initialized")
    async def add_documents(
        self,
        ids: List[str],
        embeddings: List[List[float]],
        documents: List[str],
        metadatas: List[Dict[str, Any]],
    ):
        """Store each document keyed by id, overwriting duplicates."""
        for doc_id, vector, text, meta in zip(ids, embeddings, documents, metadatas):
            self._documents[doc_id] = {
                "embedding": vector,
                "document": text,
                "metadata": meta,
            }
    async def query(
        self,
        query_embedding: List[float],
        n_results: int = 10,
        where: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Brute-force cosine-similarity search over everything in memory."""
        import math
        def cosine_similarity(a: List[float], b: List[float]) -> float:
            dot = sum(x * y for x, y in zip(a, b))
            norm_a = math.sqrt(sum(x * x for x in a))
            norm_b = math.sqrt(sum(x * x for x in b))
            if norm_a == 0 or norm_b == 0:
                return 0.0
            return dot / (norm_a * norm_b)
        hits = []
        for doc_id, record in self._documents.items():
            # Exact-match metadata filtering, when requested.
            if where and not all(
                record["metadata"].get(key) == value for key, value in where.items()
            ):
                continue
            similarity = cosine_similarity(query_embedding, record["embedding"])
            hits.append({
                "id": doc_id,
                "document": record["document"],
                "metadata": record["metadata"],
                "distance": 1 - similarity,  # convert similarity to distance
            })
        # Closest first, capped at n_results.
        hits.sort(key=lambda h: h["distance"])
        top = hits[:n_results]
        return {
            "ids": [h["id"] for h in top],
            "documents": [h["document"] for h in top],
            "metadatas": [h["metadata"] for h in top],
            "distances": [h["distance"] for h in top],
        }
    async def delete_collection(self):
        """Forget every stored document."""
        self._documents.clear()
    async def get_count(self) -> int:
        """Return the number of stored documents."""
        return len(self._documents)
class CodeIndexer:
    """
    Code indexer.

    Splits code files into chunks, embeds them, and indexes the chunks
    into a vector database.
    """
    def __init__(
        self,
        collection_name: str,
        embedding_service: Optional[EmbeddingService] = None,
        vector_store: Optional[VectorStore] = None,
        splitter: Optional[CodeSplitter] = None,
        persist_directory: Optional[str] = None,
    ):
        """
        Initialize the indexer.

        Args:
            collection_name: Vector collection name.
            embedding_service: Embedding service (a default is created if omitted).
            vector_store: Vector store backend (defaults to Chroma).
            splitter: Code splitter (a default is created if omitted).
            persist_directory: On-disk persistence directory (Chroma only).
        """
        self.collection_name = collection_name
        self.embedding_service = embedding_service or EmbeddingService()
        self.splitter = splitter or CodeSplitter()
        # Create the vector store: prefer Chroma, fall back to in-memory.
        # NOTE(review): ChromaVectorStore.__init__ only stores attributes
        # (chromadb is imported later in initialize()), so this ImportError
        # fallback likely never fires here — confirm.
        if vector_store:
            self.vector_store = vector_store
        else:
            try:
                self.vector_store = ChromaVectorStore(
                    collection_name=collection_name,
                    persist_directory=persist_directory,
                )
            except ImportError:
                logger.warning("Chroma not available, using in-memory store")
                self.vector_store = InMemoryVectorStore(collection_name=collection_name)
        self._initialized = False
    async def initialize(self):
        """Initialize the underlying vector store (idempotent)."""
        if not self._initialized:
            await self.vector_store.initialize()
            self._initialized = True
    async def index_directory(
        self,
        directory: str,
        exclude_patterns: Optional[List[str]] = None,
        include_patterns: Optional[List[str]] = None,
        progress_callback: Optional[Callable[[IndexingProgress], None]] = None,
    ) -> AsyncGenerator[IndexingProgress, None]:
        """
        Index the code files under a directory.

        Args:
            directory: Directory path to scan.
            exclude_patterns: fnmatch patterns to skip.
            include_patterns: fnmatch patterns to keep (all kept when omitted).
            progress_callback: Optional callback invoked per processed file.

        Yields:
            IndexingProgress snapshots (the same mutable object, updated in place).
        """
        await self.initialize()
        progress = IndexingProgress()
        exclude_patterns = exclude_patterns or []
        # Collect candidate files up front so total_files is known.
        files = self._collect_files(directory, exclude_patterns, include_patterns)
        progress.total_files = len(files)
        logger.info(f"Found {len(files)} files to index in {directory}")
        yield progress
        all_chunks: List[CodeChunk] = []
        # Split each file into chunks.
        for file_path in files:
            progress.current_file = file_path
            try:
                relative_path = os.path.relpath(file_path, directory)
                with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
                    content = f.read()
                if not content.strip():
                    progress.processed_files += 1
                    continue
                # Cap very large files to bound memory and token use.
                if len(content) > 500000:  # 500KB
                    content = content[:500000]
                # Split the file into chunks.
                chunks = self.splitter.split_file(content, relative_path)
                all_chunks.extend(chunks)
                progress.processed_files += 1
                progress.total_chunks = len(all_chunks)
                if progress_callback:
                    progress_callback(progress)
                yield progress
            except Exception as e:
                # Record the failure and keep going with the next file.
                logger.warning(f"Error processing {file_path}: {e}")
                progress.errors.append(f"{file_path}: {str(e)}")
                progress.processed_files += 1
        logger.info(f"Created {len(all_chunks)} chunks from {len(files)} files")
        # Embed and index all chunks in one batch pass.
        if all_chunks:
            await self._index_chunks(all_chunks, progress)
            progress.indexed_chunks = len(all_chunks)
        yield progress
    async def index_files(
        self,
        files: List[Dict[str, str]],
        base_path: str = "",
        progress_callback: Optional[Callable[[IndexingProgress], None]] = None,
    ) -> AsyncGenerator[IndexingProgress, None]:
        """
        Index an explicit list of in-memory files.

        Args:
            files: File list shaped [{"path": "...", "content": "..."}].
            base_path: Base path (currently unused by the implementation).
            progress_callback: Optional callback invoked per processed file.

        Yields:
            IndexingProgress snapshots.
        """
        await self.initialize()
        progress = IndexingProgress()
        progress.total_files = len(files)
        all_chunks: List[CodeChunk] = []
        for file_info in files:
            file_path = file_info.get("path", "")
            content = file_info.get("content", "")
            progress.current_file = file_path
            try:
                if not content.strip():
                    progress.processed_files += 1
                    continue
                # Cap very large files (same 500KB bound as index_directory).
                if len(content) > 500000:
                    content = content[:500000]
                # Split into chunks.
                chunks = self.splitter.split_file(content, file_path)
                all_chunks.extend(chunks)
                progress.processed_files += 1
                progress.total_chunks = len(all_chunks)
                if progress_callback:
                    progress_callback(progress)
                yield progress
            except Exception as e:
                logger.warning(f"Error processing {file_path}: {e}")
                progress.errors.append(f"{file_path}: {str(e)}")
                progress.processed_files += 1
        # Embed and index all chunks in one batch pass.
        if all_chunks:
            await self._index_chunks(all_chunks, progress)
            progress.indexed_chunks = len(all_chunks)
        yield progress
    async def _index_chunks(self, chunks: List[CodeChunk], progress: IndexingProgress):
        """Embed the chunks and write them to the vector store."""
        # Build the embedding text (metadata header + code) per chunk.
        texts = [chunk.to_embedding_text() for chunk in chunks]
        logger.info(f"Generating embeddings for {len(texts)} chunks...")
        # Batch-embed all chunk texts.
        embeddings = await self.embedding_service.embed_batch(texts, batch_size=50)
        # Prepare parallel id/document/metadata lists.
        ids = [chunk.id for chunk in chunks]
        documents = [chunk.content for chunk in chunks]
        metadatas = [chunk.to_dict() for chunk in chunks]
        # Write everything to the vector store.
        logger.info(f"Adding {len(chunks)} chunks to vector store...")
        await self.vector_store.add_documents(
            ids=ids,
            embeddings=embeddings,
            documents=documents,
            metadatas=metadatas,
        )
        logger.info(f"Indexed {len(chunks)} chunks successfully")
    def _collect_files(
        self,
        directory: str,
        exclude_patterns: List[str],
        include_patterns: Optional[List[str]],
    ) -> List[str]:
        """Walk the directory and return the files eligible for indexing."""
        import fnmatch
        files = []
        for root, dirs, filenames in os.walk(directory):
            # Prune excluded directories in place so os.walk skips them.
            dirs[:] = [d for d in dirs if d not in EXCLUDE_DIRS]
            for filename in filenames:
                # Only index known text/source extensions.
                ext = os.path.splitext(filename)[1].lower()
                if ext not in TEXT_EXTENSIONS:
                    continue
                # Skip lockfiles and OS metadata.
                if filename in EXCLUDE_FILES:
                    continue
                file_path = os.path.join(root, filename)
                relative_path = os.path.relpath(file_path, directory)
                # Apply caller-supplied exclude patterns.
                excluded = False
                for pattern in exclude_patterns:
                    if fnmatch.fnmatch(relative_path, pattern) or fnmatch.fnmatch(filename, pattern):
                        excluded = True
                        break
                if excluded:
                    continue
                # Apply caller-supplied include patterns (when given,
                # a file must match at least one).
                if include_patterns:
                    included = False
                    for pattern in include_patterns:
                        if fnmatch.fnmatch(relative_path, pattern) or fnmatch.fnmatch(filename, pattern):
                            included = True
                            break
                    if not included:
                        continue
                files.append(file_path)
        return files
    async def get_chunk_count(self) -> int:
        """Return the number of chunks currently indexed."""
        await self.initialize()
        return await self.vector_store.get_count()
    async def clear(self):
        """Delete the collection and mark the indexer uninitialized."""
        await self.initialize()
        await self.vector_store.delete_collection()
        self._initialized = False

View File

@ -0,0 +1,469 @@
"""
代码检索器
支持语义检索和混合检索
"""
import re
import logging
from typing import List, Dict, Any, Optional
from dataclasses import dataclass, field
from .embeddings import EmbeddingService
from .indexer import VectorStore, ChromaVectorStore, InMemoryVectorStore
from .splitter import CodeChunk, ChunkType
logger = logging.getLogger(__name__)
@dataclass
class RetrievalResult:
    """A single retrieval hit returned by the code retriever."""
    chunk_id: str
    content: str
    file_path: str
    language: str
    chunk_type: str
    line_start: int
    line_end: int
    score: float  # similarity score in [0, 1]; higher is more similar
    # Optional semantic metadata about the chunk.
    name: Optional[str] = None
    parent_name: Optional[str] = None
    signature: Optional[str] = None
    security_indicators: List[str] = field(default_factory=list)
    # Raw metadata exactly as stored in the vector store.
    metadata: Dict[str, Any] = field(default_factory=dict)
    def to_dict(self) -> Dict[str, Any]:
        """Serialize the hit to a plain dict (raw metadata is omitted)."""
        exported = (
            "chunk_id", "content", "file_path", "language", "chunk_type",
            "line_start", "line_end", "score", "name", "parent_name",
            "signature", "security_indicators",
        )
        return {attr: getattr(self, attr) for attr in exported}
    def to_context_string(self, include_metadata: bool = True) -> str:
        """Render the hit as a fenced code block, optionally with a header (for LLM input)."""
        sections = []
        if include_metadata:
            header = f"File: {self.file_path}"
            if self.line_start and self.line_end:
                header += f" (lines {self.line_start}-{self.line_end})"
            if self.name:
                header += f"\n{self.chunk_type.title()}: {self.name}"
            if self.parent_name:
                header += f" in {self.parent_name}"
            sections.append(header)
        sections.append(f"```{self.language}\n{self.content}\n```")
        return "\n".join(sections)
class CodeRetriever:
    """
    Code retriever.

    Supports semantic retrieval, keyword-boosted hybrid retrieval, and a
    few convenience lookups (by file, by security topic, function context).
    """
    def __init__(
        self,
        collection_name: str,
        embedding_service: Optional[EmbeddingService] = None,
        vector_store: Optional[VectorStore] = None,
        persist_directory: Optional[str] = None,
    ):
        """
        Initialize the retriever.

        Args:
            collection_name: Vector collection name.
            embedding_service: Embedding service (a default is created if omitted).
            vector_store: Vector store backend (defaults to Chroma).
            persist_directory: On-disk persistence directory (Chroma only).
        """
        self.collection_name = collection_name
        self.embedding_service = embedding_service or EmbeddingService()
        # Create the vector store: prefer Chroma, fall back to in-memory.
        if vector_store:
            self.vector_store = vector_store
        else:
            try:
                self.vector_store = ChromaVectorStore(
                    collection_name=collection_name,
                    persist_directory=persist_directory,
                )
            except ImportError:
                logger.warning("Chroma not available, using in-memory store")
                self.vector_store = InMemoryVectorStore(collection_name=collection_name)
        self._initialized = False
    async def initialize(self):
        """Initialize the underlying vector store (idempotent)."""
        if not self._initialized:
            await self.vector_store.initialize()
            self._initialized = True
    async def retrieve(
        self,
        query: str,
        top_k: int = 10,
        filter_file_path: Optional[str] = None,
        filter_language: Optional[str] = None,
        filter_chunk_type: Optional[str] = None,
        min_score: float = 0.0,
    ) -> List[RetrievalResult]:
        """
        Semantic retrieval.

        Args:
            query: Query text.
            top_k: Number of results to return.
            filter_file_path: Exact file-path filter.
            filter_language: Language filter.
            filter_chunk_type: Chunk-type filter.
            min_score: Minimum similarity score to keep.

        Returns:
            Retrieval results sorted by score, best first.
        """
        # Hoisted out of the per-result loop (was re-imported every iteration).
        import json
        await self.initialize()
        # Embed the query once.
        query_embedding = await self.embedding_service.embed(query)
        # Build exact-match metadata filters.
        where = {}
        if filter_file_path:
            where["file_path"] = filter_file_path
        if filter_language:
            where["language"] = filter_language
        if filter_chunk_type:
            where["chunk_type"] = filter_chunk_type
        # Over-fetch so min_score filtering still leaves enough hits.
        raw_results = await self.vector_store.query(
            query_embedding=query_embedding,
            n_results=top_k * 2,
            where=where if where else None,
        )
        # Convert raw store rows into RetrievalResult objects.
        results = []
        for i, (id_, doc, meta, dist) in enumerate(zip(
            raw_results["ids"],
            raw_results["documents"],
            raw_results["metadatas"],
            raw_results["distances"],
        )):
            # Convert cosine distance to a similarity score.
            score = 1 - dist
            if score < min_score:
                continue
            # security_indicators may be stored as a JSON string (Chroma
            # metadata cannot hold lists); decode it defensively.
            security_indicators = meta.get("security_indicators", [])
            if isinstance(security_indicators, str):
                try:
                    security_indicators = json.loads(security_indicators)
                except (json.JSONDecodeError, TypeError):
                    # Fixed: was a bare `except:` that also swallowed
                    # KeyboardInterrupt/SystemExit.
                    security_indicators = []
            result = RetrievalResult(
                chunk_id=id_,
                content=doc,
                file_path=meta.get("file_path", ""),
                language=meta.get("language", "text"),
                chunk_type=meta.get("chunk_type", "unknown"),
                line_start=meta.get("line_start", 0),
                line_end=meta.get("line_end", 0),
                score=score,
                name=meta.get("name"),
                parent_name=meta.get("parent_name"),
                signature=meta.get("signature"),
                security_indicators=security_indicators,
                metadata=meta,
            )
            results.append(result)
        # Best score first, capped at top_k.
        results.sort(key=lambda x: x.score, reverse=True)
        return results[:top_k]
    async def retrieve_by_file(
        self,
        file_path: str,
        top_k: int = 50,
    ) -> List[RetrievalResult]:
        """
        Retrieve all indexed chunks of one file, ordered by line number.

        Args:
            file_path: File path to look up.
            top_k: Maximum number of chunks to return.

        Returns:
            The file's chunks sorted by line_start.
        """
        await self.initialize()
        # A vector query still needs an embedding; use a generic one.
        query_embedding = await self.embedding_service.embed(f"code in {file_path}")
        raw_results = await self.vector_store.query(
            query_embedding=query_embedding,
            n_results=top_k,
            where={"file_path": file_path},
        )
        results = []
        for id_, doc, meta, dist in zip(
            raw_results["ids"],
            raw_results["documents"],
            raw_results["metadatas"],
            raw_results["distances"],
        ):
            result = RetrievalResult(
                chunk_id=id_,
                content=doc,
                file_path=meta.get("file_path", ""),
                language=meta.get("language", "text"),
                chunk_type=meta.get("chunk_type", "unknown"),
                line_start=meta.get("line_start", 0),
                line_end=meta.get("line_end", 0),
                score=1 - dist,
                name=meta.get("name"),
                parent_name=meta.get("parent_name"),
                metadata=meta,
            )
            results.append(result)
        # Present the file top-to-bottom rather than by similarity.
        results.sort(key=lambda x: x.line_start)
        return results
    async def retrieve_security_related(
        self,
        vulnerability_type: Optional[str] = None,
        top_k: int = 20,
    ) -> List[RetrievalResult]:
        """
        Retrieve code related to security concerns.

        Args:
            vulnerability_type: Vulnerability type, e.g. sql_injection, xss.
            top_k: Number of results.

        Returns:
            Security-relevant code chunks.
        """
        # Map each vulnerability type to a retrieval query.
        security_queries = {
            "sql_injection": "SQL query execute database user input",
            "xss": "HTML render user input innerHTML template",
            "command_injection": "system exec command shell subprocess",
            "path_traversal": "file path read open user input",
            "ssrf": "HTTP request URL user input fetch",
            "deserialization": "deserialize pickle yaml load object",
            "auth_bypass": "authentication login password token session",
            "hardcoded_secret": "password secret key token credential",
        }
        if vulnerability_type and vulnerability_type in security_queries:
            query = security_queries[vulnerability_type]
        else:
            # Generic security query when the type is unknown.
            query = "security vulnerability dangerous function user input"
        return await self.retrieve(query, top_k=top_k)
    async def retrieve_function_context(
        self,
        function_name: str,
        file_path: Optional[str] = None,
        include_callers: bool = True,
        include_callees: bool = True,
        top_k: int = 10,
    ) -> Dict[str, List[RetrievalResult]]:
        """
        Retrieve the context of a function: its definition, callers and callees.

        Args:
            function_name: Function name.
            file_path: Optional file path to restrict the definition search.
            include_callers: Whether to look up call sites.
            include_callees: Whether to look up functions it calls.
            top_k: Number of results per category.

        Returns:
            Dict with "definition", "callers" and "callees" result lists.
        """
        context = {
            "definition": [],
            "callers": [],
            "callees": [],
        }
        # Find the function definition.
        definition_query = f"function definition {function_name}"
        definitions = await self.retrieve(
            definition_query,
            top_k=5,
            filter_file_path=file_path,
        )
        # Keep only hits that plausibly contain the definition.
        for result in definitions:
            if result.name == function_name or function_name in (result.content or ""):
                context["definition"].append(result)
        if include_callers:
            # Find code that calls this function.
            caller_query = f"calls {function_name} invoke {function_name}"
            callers = await self.retrieve(caller_query, top_k=top_k)
            for result in callers:
                # Require an actual call expression, not just a mention.
                if re.search(rf'\b{re.escape(function_name)}\s*\(', result.content):
                    if result not in context["definition"]:
                        context["callers"].append(result)
        if include_callees and context["definition"]:
            # Extract called names from the definition bodies.
            for definition in context["definition"]:
                calls = re.findall(r'\b(\w+)\s*\(', definition.content)
                unique_calls = list(set(calls))[:5]  # cap the fan-out
                for call in unique_calls:
                    if call == function_name:
                        continue
                    callees = await self.retrieve(
                        f"function {call} definition",
                        top_k=2,
                    )
                    context["callees"].extend(callees)
        return context
    async def retrieve_similar_code(
        self,
        code_snippet: str,
        top_k: int = 5,
        exclude_file: Optional[str] = None,
    ) -> List[RetrievalResult]:
        """
        Retrieve code similar to a snippet.

        Args:
            code_snippet: Code fragment to match.
            top_k: Number of results.
            exclude_file: File to exclude from results.

        Returns:
            Similar code chunks.
        """
        # Over-fetch so the exclusion filter still leaves top_k results.
        results = await self.retrieve(
            f"similar code: {code_snippet}",
            top_k=top_k * 2,
        )
        if exclude_file:
            results = [r for r in results if r.file_path != exclude_file]
        return results[:top_k]
    async def hybrid_retrieve(
        self,
        query: str,
        keywords: Optional[List[str]] = None,
        top_k: int = 10,
        semantic_weight: float = 0.7,
    ) -> List[RetrievalResult]:
        """
        Hybrid retrieval (semantic + keyword).

        Args:
            query: Query text.
            keywords: Extra keywords to boost.
            top_k: Number of results.
            semantic_weight: Weight of the semantic score in [0, 1].

        Returns:
            Retrieval results re-ranked by the blended score.
        """
        # Semantic pass, over-fetched for re-ranking.
        semantic_results = await self.retrieve(query, top_k=top_k * 2)
        # Re-rank with keyword matches when keywords are provided.
        if keywords:
            keyword_pattern = '|'.join(re.escape(kw) for kw in keywords)
            enhanced_results = []
            for result in semantic_results:
                # Keyword coverage, capped at 1.0.
                matches = len(re.findall(keyword_pattern, result.content, re.IGNORECASE))
                keyword_score = min(1.0, matches / len(keywords))
                # Blend semantic and keyword scores.
                hybrid_score = (
                    semantic_weight * result.score +
                    (1 - semantic_weight) * keyword_score
                )
                result.score = hybrid_score
                enhanced_results.append(result)
            enhanced_results.sort(key=lambda x: x.score, reverse=True)
            return enhanced_results[:top_k]
        return semantic_results[:top_k]
    def format_results_for_llm(
        self,
        results: List[RetrievalResult],
        max_tokens: int = 4000,
        include_metadata: bool = True,
    ) -> str:
        """
        Format retrieval results as LLM input.

        Args:
            results: Retrieval results.
            max_tokens: Approximate token budget (chars / 4 heuristic).
            include_metadata: Whether to include per-chunk headers.

        Returns:
            A formatted string, truncated to the token budget.
        """
        if not results:
            return "No relevant code found."
        parts = []
        total_tokens = 0
        for i, result in enumerate(results):
            context = result.to_context_string(include_metadata=include_metadata)
            # ~4 characters per token heuristic.
            estimated_tokens = len(context) // 4
            if total_tokens + estimated_tokens > max_tokens:
                break
            parts.append(f"### Code Block {i + 1} (Score: {result.score:.2f})\n{context}")
            total_tokens += estimated_tokens
        return "\n\n".join(parts)

View File

@ -0,0 +1,785 @@
"""
代码分块器 - 基于 Tree-sitter AST 的智能代码分块
使用先进的 Python 库实现专业级代码解析
"""
import re
import hashlib
import logging
from typing import List, Dict, Any, Optional, Tuple, Set
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
logger = logging.getLogger(__name__)
class ChunkType(Enum):
    """Kinds of code chunks produced by the splitter."""
    # Whole-file / module granularity.
    FILE = "file"
    MODULE = "module"
    # Named definitions.
    CLASS = "class"
    FUNCTION = "function"
    METHOD = "method"
    INTERFACE = "interface"
    STRUCT = "struct"
    ENUM = "enum"
    # Non-definition fragments.
    IMPORT = "import"
    CONSTANT = "constant"
    CONFIG = "config"
    COMMENT = "comment"
    DECORATOR = "decorator"
    # Fallback when the splitter cannot classify a chunk.
    UNKNOWN = "unknown"
@dataclass
class CodeChunk:
    """A single chunk of source code plus the metadata extracted for it."""
    id: str
    content: str
    file_path: str
    language: str
    chunk_type: ChunkType
    # Location within the source file.
    line_start: int = 0
    line_end: int = 0
    byte_start: int = 0
    byte_end: int = 0
    # Semantic information.
    name: Optional[str] = None
    parent_name: Optional[str] = None
    signature: Optional[str] = None
    docstring: Optional[str] = None
    # AST node type, when parsed with tree-sitter.
    ast_type: Optional[str] = None
    # Relationship information.
    imports: List[str] = field(default_factory=list)
    calls: List[str] = field(default_factory=list)
    dependencies: List[str] = field(default_factory=list)
    definitions: List[str] = field(default_factory=list)
    # Security-relevant indicators found in the chunk.
    security_indicators: List[str] = field(default_factory=list)
    # Free-form metadata.
    metadata: Dict[str, Any] = field(default_factory=dict)
    # Estimated token count of `content`.
    estimated_tokens: int = 0
    def __post_init__(self):
        # Derive the id and token estimate lazily when the caller left them unset.
        if not self.id:
            self.id = self._generate_id()
        if not self.estimated_tokens:
            self.estimated_tokens = self._estimate_tokens()
    def _generate_id(self) -> str:
        """Deterministic 16-hex-char id from location plus a content prefix."""
        fingerprint = f"{self.file_path}:{self.line_start}:{self.line_end}:{self.content[:100]}"
        return hashlib.sha256(fingerprint.encode()).hexdigest()[:16]
    def _estimate_tokens(self) -> int:
        """Token count via tiktoken when installed, else a chars/4 heuristic."""
        try:
            import tiktoken
        except ImportError:
            return len(self.content) // 4
        enc = tiktoken.get_encoding("cl100k_base")
        return len(enc.encode(self.content))
    def to_dict(self) -> Dict[str, Any]:
        """Serialize the chunk for storage as vector-store metadata."""
        plain_fields = (
            "id", "content", "file_path", "language", "line_start", "line_end",
            "name", "parent_name", "signature", "docstring", "ast_type",
            "imports", "calls", "definitions", "security_indicators",
            "estimated_tokens", "metadata",
        )
        payload = {attr: getattr(self, attr) for attr in plain_fields}
        # Enum values are not serializable as-is; store the string value.
        payload["chunk_type"] = self.chunk_type.value
        return payload
    def to_embedding_text(self) -> str:
        """Build the text to embed: a metadata header followed by the code."""
        lines = [f"File: {self.file_path}"]
        if self.name:
            lines.append(f"{self.chunk_type.value.title()}: {self.name}")
        if self.parent_name:
            lines.append(f"In: {self.parent_name}")
        if self.signature:
            lines.append(f"Signature: {self.signature}")
        if self.docstring:
            lines.append(f"Description: {self.docstring[:300]}")
        lines.append(f"Code:\n{self.content}")
        return "\n".join(lines)
class TreeSitterParser:
    """
    Tree-sitter based code parser.

    Provides AST-level code analysis; parsers are loaded lazily per language.
    """
    # File extension -> tree-sitter language name.
    LANGUAGE_MAP = {
        ".py": "python",
        ".js": "javascript",
        ".jsx": "javascript",
        ".ts": "typescript",
        ".tsx": "tsx",
        ".java": "java",
        ".go": "go",
        ".rs": "rust",
        ".cpp": "cpp",
        ".c": "c",
        ".h": "c",
        ".hpp": "cpp",
        ".cs": "c_sharp",
        ".php": "php",
        ".rb": "ruby",
        ".kt": "kotlin",
        ".swift": "swift",
    }
    # Per-language AST node types that mark function/class definitions.
    DEFINITION_TYPES = {
        "python": {
            "class": ["class_definition"],
            "function": ["function_definition"],
            "method": ["function_definition"],
            "import": ["import_statement", "import_from_statement"],
        },
        "javascript": {
            "class": ["class_declaration", "class"],
            "function": ["function_declaration", "function", "arrow_function", "method_definition"],
            "import": ["import_statement"],
        },
        "typescript": {
            "class": ["class_declaration", "class"],
            "function": ["function_declaration", "function", "arrow_function", "method_definition"],
            "interface": ["interface_declaration"],
            "import": ["import_statement"],
        },
        "java": {
            "class": ["class_declaration"],
            "method": ["method_declaration", "constructor_declaration"],
            "interface": ["interface_declaration"],
            "import": ["import_declaration"],
        },
        "go": {
            "struct": ["type_declaration"],
            "function": ["function_declaration", "method_declaration"],
            "interface": ["type_declaration"],
            "import": ["import_declaration"],
        },
    }
    def __init__(self):
        self._parsers: Dict[str, Any] = {}
        self._initialized = False
    def _ensure_initialized(self, language: str) -> bool:
        """Lazily load the parser for `language`; False when unavailable."""
        if language in self._parsers:
            return True
        try:
            from tree_sitter_languages import get_parser, get_language
            self._parsers[language] = get_parser(language)
            return True
        except ImportError:
            logger.warning("tree-sitter-languages not installed, falling back to regex parsing")
            return False
        except Exception as e:
            logger.warning(f"Failed to load tree-sitter parser for {language}: {e}")
            return False
    def parse(self, code: str, language: str) -> Optional[Any]:
        """Parse source code and return a tree-sitter tree, or None on failure."""
        if not self._ensure_initialized(language):
            return None
        parser = self._parsers.get(language)
        if not parser:
            return None
        try:
            return parser.parse(code.encode())
        except Exception as e:
            logger.warning(f"Failed to parse code: {e}")
            return None
    def extract_definitions(self, tree: Any, code: str, language: str) -> List[Dict[str, Any]]:
        """Walk the AST and collect class/function/import definition records."""
        if tree is None:
            return []
        found: List[Dict[str, Any]] = []
        kind_table = self.DEFINITION_TYPES.get(language, {})
        def walk(node, enclosing=None):
            for category, node_types in kind_table.items():
                if node.type in node_types:
                    ident = self._extract_name(node, language)
                    found.append({
                        "type": category,
                        "name": ident,
                        "parent_name": enclosing,
                        "start_point": node.start_point,
                        "end_point": node.end_point,
                        "start_byte": node.start_byte,
                        "end_byte": node.end_byte,
                        "node_type": node.type,
                    })
                    # Descend into classes so their methods record the
                    # class name as parent; other definitions stop here.
                    if category == "class":
                        for child in node.children:
                            walk(child, ident)
                    return
            for child in node.children:
                walk(child, enclosing)
        walk(tree.root_node)
        return found
    def _extract_name(self, node: Any, language: str) -> Optional[str]:
        """Pull the identifier text out of a definition node, if present."""
        def as_text(child):
            return child.text.decode() if isinstance(child.text, bytes) else child.text
        name_types = ("identifier", "name", "type_identifier", "property_identifier")
        for child in node.children:
            if child.type in name_types:
                return as_text(child)
        # Language-specific fallback (kept from the original implementation).
        if language == "python":
            for child in node.children:
                if child.type == "name":
                    return as_text(child)
        return None
class CodeSplitter:
    """Advanced code chunker.

    Splits source files into semantically meaningful chunks using
    Tree-sitter AST parsing when available, falling back to per-language
    regex chunking and finally to plain line windows.  Each chunk is
    tagged with security indicators and enriched with imports/calls/
    definitions metadata.

    ``CodeChunk`` annotations are written as forward references so the
    class does not depend on declaration order within the module.
    """

    # Dangerous function / pattern regexes per language, used to attach
    # security indicators to chunks.
    SECURITY_PATTERNS = {
        "python": [
            (r"\bexec\s*\(", "exec"),
            (r"\beval\s*\(", "eval"),
            (r"\bcompile\s*\(", "compile"),
            (r"\bos\.system\s*\(", "os_system"),
            (r"\bsubprocess\.", "subprocess"),
            (r"\bcursor\.execute\s*\(", "sql_execute"),
            (r"\.execute\s*\(.*%", "sql_format"),
            (r"\bpickle\.loads?\s*\(", "pickle"),
            (r"\byaml\.load\s*\(", "yaml_load"),
            (r"\brequests?\.", "http_request"),
            (r"password\s*=", "password_assign"),
            (r"secret\s*=", "secret_assign"),
            (r"api_key\s*=", "api_key_assign"),
        ],
        "javascript": [
            (r"\beval\s*\(", "eval"),
            (r"\bFunction\s*\(", "function_constructor"),
            (r"innerHTML\s*=", "innerHTML"),
            (r"outerHTML\s*=", "outerHTML"),
            (r"document\.write\s*\(", "document_write"),
            (r"\.exec\s*\(", "exec"),
            (r"\.query\s*\(.*\+", "sql_concat"),
            (r"password\s*[=:]", "password_assign"),
            (r"apiKey\s*[=:]", "api_key_assign"),
        ],
        "java": [
            (r"Runtime\.getRuntime\(\)\.exec", "runtime_exec"),
            (r"ProcessBuilder", "process_builder"),
            (r"\.executeQuery\s*\(.*\+", "sql_concat"),
            (r"ObjectInputStream", "deserialization"),
            (r"XMLDecoder", "xml_decoder"),
            (r"password\s*=", "password_assign"),
        ],
        "go": [
            (r"exec\.Command\s*\(", "exec_command"),
            (r"\.Query\s*\(.*\+", "sql_concat"),
            (r"\.Exec\s*\(.*\+", "sql_concat"),
            (r"template\.HTML\s*\(", "unsafe_html"),
            (r"password\s*=", "password_assign"),
        ],
        "php": [
            (r"\beval\s*\(", "eval"),
            (r"\bexec\s*\(", "exec"),
            (r"\bsystem\s*\(", "system"),
            (r"\bshell_exec\s*\(", "shell_exec"),
            (r"\$_GET\[", "get_input"),
            (r"\$_POST\[", "post_input"),
            (r"\$_REQUEST\[", "request_input"),
        ],
    }

    def __init__(
        self,
        max_chunk_size: int = 1500,
        min_chunk_size: int = 100,
        overlap_size: int = 50,
        preserve_structure: bool = True,
        use_tree_sitter: bool = True,
    ):
        # Sizes are expressed in estimated tokens (~4 chars per token,
        # see CodeChunk.estimated_tokens).
        self.max_chunk_size = max_chunk_size
        self.min_chunk_size = min_chunk_size
        self.overlap_size = overlap_size
        self.preserve_structure = preserve_structure
        self.use_tree_sitter = use_tree_sitter
        self._ts_parser = TreeSitterParser() if use_tree_sitter else None

    def detect_language(self, file_path: str) -> str:
        """Detect the programming language from the file extension."""
        ext = Path(file_path).suffix.lower()
        return TreeSitterParser.LANGUAGE_MAP.get(ext, "text")

    def split_file(
        self,
        content: str,
        file_path: str,
        language: Optional[str] = None
    ) -> List["CodeChunk"]:
        """Split a single file into code chunks.

        Args:
            content: File contents.
            file_path: Path of the file (used for language detection).
            language: Programming language; auto-detected when omitted.

        Returns:
            List of code chunks (empty for blank content).
        """
        if not content or not content.strip():
            return []
        if language is None:
            language = self.detect_language(file_path)

        chunks: List["CodeChunk"] = []
        try:
            # Preferred strategy: Tree-sitter AST chunking.
            if self.use_tree_sitter and self._ts_parser:
                tree = self._ts_parser.parse(content, language)
                if tree:
                    chunks = self._split_by_ast(content, file_path, language, tree)
            # Fallback 1: enhanced regex chunking.
            if not chunks:
                chunks = self._split_by_enhanced_regex(content, file_path, language)
            # Fallback 2: plain line-window chunking.
            if not chunks:
                chunks = self._split_by_lines(content, file_path, language)
            # Post-processing: tag each chunk with security indicators.
            for chunk in chunks:
                chunk.security_indicators = self._extract_security_indicators(
                    chunk.content, language
                )
            # Post-processing: enrich with imports/calls/definitions.
            self._enrich_chunks_with_semantics(chunks, content, language)
        except Exception as e:
            logger.warning(f"分块失败 {file_path}: {e}, 使用简单分块")
            chunks = self._split_by_lines(content, file_path, language)
        return chunks

    def _split_by_ast(
        self,
        content: str,
        file_path: str,
        language: str,
        tree: Any
    ) -> List["CodeChunk"]:
        """Create one chunk per AST definition (class/function/...)."""
        chunks: List["CodeChunk"] = []
        lines = content.split('\n')
        definitions = self._ts_parser.extract_definitions(tree, content, language)
        if not definitions:
            return []
        # Map extractor categories to chunk types; unknown -> MODULE.
        type_map = {
            "class": ChunkType.CLASS,
            "function": ChunkType.FUNCTION,
            "method": ChunkType.FUNCTION,
            "interface": ChunkType.INTERFACE,
            "struct": ChunkType.STRUCT,
            "import": ChunkType.IMPORT,
        }
        for defn in definitions:
            # tree-sitter points are 0-based (row, col).
            start_line = defn["start_point"][0]
            end_line = defn["end_point"][0]
            chunk_content = '\n'.join(lines[start_line:end_line + 1])
            # Skip trivially small fragments.
            if len(chunk_content.strip()) < self.min_chunk_size // 4:
                continue
            chunk = CodeChunk(
                id="",
                content=chunk_content,
                file_path=file_path,
                language=language,
                chunk_type=type_map.get(defn["type"], ChunkType.MODULE),
                line_start=start_line + 1,
                line_end=end_line + 1,
                byte_start=defn["start_byte"],
                byte_end=defn["end_byte"],
                name=defn.get("name"),
                parent_name=defn.get("parent_name"),
                ast_type=defn.get("node_type"),
            )
            # Oversized definitions are split further into windows.
            if chunk.estimated_tokens > self.max_chunk_size:
                chunks.extend(self._split_large_chunk(chunk))
            else:
                chunks.append(chunk)
        return chunks

    def _split_by_enhanced_regex(
        self,
        content: str,
        file_path: str,
        language: str
    ) -> List["CodeChunk"]:
        """Regex-based chunking fallback supporting several languages.

        Every pattern captures group 1 = leading indentation and
        group 2 = definition name so indentation-based scoping works
        uniformly.  (Bug fix: the Go patterns previously captured only
        the name, so `indent` became the name's length and the name was
        never extracted.)
        """
        lines = content.split('\n')
        # Definition-start patterns per language.
        patterns = {
            "python": [
                (r"^(\s*)class\s+(\w+)(?:\s*\([^)]*\))?\s*:", ChunkType.CLASS),
                (r"^(\s*)(?:async\s+)?def\s+(\w+)\s*\([^)]*\)\s*(?:->[^:]+)?:", ChunkType.FUNCTION),
            ],
            "javascript": [
                (r"^(\s*)(?:export\s+)?class\s+(\w+)", ChunkType.CLASS),
                (r"^(\s*)(?:export\s+)?(?:async\s+)?function\s*(\w*)\s*\(", ChunkType.FUNCTION),
                (r"^(\s*)(?:export\s+)?(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s+)?\([^)]*\)\s*=>", ChunkType.FUNCTION),
            ],
            "typescript": [
                (r"^(\s*)(?:export\s+)?(?:abstract\s+)?class\s+(\w+)", ChunkType.CLASS),
                (r"^(\s*)(?:export\s+)?interface\s+(\w+)", ChunkType.INTERFACE),
                (r"^(\s*)(?:export\s+)?(?:async\s+)?function\s*(\w*)", ChunkType.FUNCTION),
            ],
            "java": [
                (r"^(\s*)(?:public|private|protected)?\s*(?:static\s+)?(?:final\s+)?class\s+(\w+)", ChunkType.CLASS),
                (r"^(\s*)(?:public|private|protected)?\s*interface\s+(\w+)", ChunkType.INTERFACE),
                (r"^(\s*)(?:public|private|protected)?\s*(?:static\s+)?[\w<>\[\],\s]+\s+(\w+)\s*\([^)]*\)\s*(?:throws\s+[\w,\s]+)?\s*\{", ChunkType.METHOD),
            ],
            "go": [
                # Fixed: leading (\s*) group added so group 1/2 line up
                # with the other languages (Go top-level code matches
                # the empty indent).
                (r"^(\s*)type\s+(\w+)\s+struct\s*\{", ChunkType.STRUCT),
                (r"^(\s*)type\s+(\w+)\s+interface\s*\{", ChunkType.INTERFACE),
                (r"^(\s*)func\s+(?:\([^)]+\)\s+)?(\w+)\s*\([^)]*\)", ChunkType.FUNCTION),
            ],
            "php": [
                (r"^(\s*)(?:abstract\s+)?class\s+(\w+)", ChunkType.CLASS),
                (r"^(\s*)interface\s+(\w+)", ChunkType.INTERFACE),
                (r"^(\s*)(?:public|private|protected)?\s*(?:static\s+)?function\s+(\w+)", ChunkType.FUNCTION),
            ],
        }
        lang_patterns = patterns.get(language, [])
        if not lang_patterns:
            return []

        # Locate every definition start line.
        definitions = []
        for i, line in enumerate(lines):
            for pattern, chunk_type in lang_patterns:
                match = re.match(pattern, line)
                if match:
                    # lastindex is None when no group participated;
                    # normalize to 0 to avoid a TypeError on comparison.
                    groups = match.lastindex or 0
                    definitions.append({
                        "line": i,
                        "indent": len(match.group(1)) if groups >= 1 else 0,
                        "name": match.group(2) if groups >= 2 else None,
                        "type": chunk_type,
                    })
                    break
        if not definitions:
            return []

        chunks: List["CodeChunk"] = []
        # Determine each definition's extent via indentation.
        for defn in definitions:
            start_line = defn["line"]
            base_indent = defn["indent"]
            end_line = len(lines) - 1
            for j in range(start_line + 1, len(lines)):
                line = lines[j]
                if line.strip():
                    current_indent = len(line) - len(line.lstrip())
                    if current_indent <= base_indent:
                        # A dedent, or the next definition at the same
                        # level, terminates the current block.
                        is_next_def = any(d["line"] == j for d in definitions)
                        if is_next_def or (current_indent < base_indent):
                            end_line = j - 1
                            break
            chunk_content = '\n'.join(lines[start_line:end_line + 1])
            if len(chunk_content.strip()) < 10:
                continue
            chunk = CodeChunk(
                id="",
                content=chunk_content,
                file_path=file_path,
                language=language,
                chunk_type=defn["type"],
                line_start=start_line + 1,
                line_end=end_line + 1,
                name=defn.get("name"),
            )
            if chunk.estimated_tokens > self.max_chunk_size:
                chunks.extend(self._split_large_chunk(chunk))
            else:
                chunks.append(chunk)
        return chunks

    def _split_by_lines(
        self,
        content: str,
        file_path: str,
        language: str
    ) -> List["CodeChunk"]:
        """Fixed-size line-window chunking (last-resort fallback)."""
        chunks: List["CodeChunk"] = []
        lines = content.split('\n')
        # Rough token estimate: ~4 characters per token.
        total_tokens = len(content) // 4
        avg_tokens_per_line = max(1, total_tokens // max(1, len(lines)))
        lines_per_chunk = max(10, self.max_chunk_size // avg_tokens_per_line)
        overlap_lines = self.overlap_size // avg_tokens_per_line
        # Defensive: never allow a non-positive stride (overlap >= window
        # would otherwise make the range empty or loop forever).
        step = max(1, lines_per_chunk - overlap_lines)
        for i in range(0, len(lines), step):
            end = min(i + lines_per_chunk, len(lines))
            chunk_content = '\n'.join(lines[i:end])
            if len(chunk_content.strip()) < 10:
                continue
            chunks.append(CodeChunk(
                id="",
                content=chunk_content,
                file_path=file_path,
                language=language,
                chunk_type=ChunkType.MODULE,
                line_start=i + 1,
                line_end=end,
            ))
            if end >= len(lines):
                break
        return chunks

    def _split_large_chunk(self, chunk: "CodeChunk") -> List["CodeChunk"]:
        """Split an oversized chunk into line windows; keep the original
        if no sub-chunk survives the minimum-size filter."""
        sub_chunks: List["CodeChunk"] = []
        lines = chunk.content.split('\n')
        avg_tokens_per_line = max(1, chunk.estimated_tokens // max(1, len(lines)))
        lines_per_chunk = max(10, self.max_chunk_size // avg_tokens_per_line)
        for i in range(0, len(lines), lines_per_chunk):
            end = min(i + lines_per_chunk, len(lines))
            sub_content = '\n'.join(lines[i:end])
            if len(sub_content.strip()) < 10:
                continue
            sub_chunks.append(CodeChunk(
                id="",
                content=sub_content,
                file_path=chunk.file_path,
                language=chunk.language,
                chunk_type=chunk.chunk_type,
                line_start=chunk.line_start + i,
                line_end=chunk.line_start + end - 1,
                name=chunk.name,
                parent_name=chunk.parent_name,
            ))
        return sub_chunks if sub_chunks else [chunk]

    def _extract_security_indicators(self, content: str, language: str) -> List[str]:
        """Collect security-relevant pattern names found in *content*.

        Combines language-specific patterns with generic secret-related
        keywords; capped at 15 indicators per chunk.
        """
        indicators: List[str] = []
        patterns = self.SECURITY_PATTERNS.get(language, [])
        # Generic secret/credential keywords for every language.
        common_patterns = [
            (r"password", "password"),
            (r"secret", "secret"),
            (r"api[_-]?key", "api_key"),
            (r"token", "token"),
            (r"private[_-]?key", "private_key"),
            (r"credential", "credential"),
        ]
        for pattern, name in patterns + common_patterns:
            try:
                if re.search(pattern, content, re.IGNORECASE):
                    if name not in indicators:
                        indicators.append(name)
            except re.error:
                # Skip malformed patterns rather than failing the chunk.
                continue
        return indicators[:15]

    def _enrich_chunks_with_semantics(
        self,
        chunks: List["CodeChunk"],
        full_content: str,
        language: str
    ):
        """Attach imports, function calls and defined names to each chunk."""
        imports = self._extract_imports(full_content, language)
        for chunk in chunks:
            # Only imports actually referenced by the chunk's text.
            chunk.imports = self._filter_relevant_imports(imports, chunk.content)
            chunk.calls = self._extract_function_calls(chunk.content, language)
            chunk.definitions = self._extract_definitions(chunk.content, language)

    def _extract_imports(self, content: str, language: str) -> List[str]:
        """Extract imported module names (deduplicated, order unspecified)."""
        imports: List[str] = []
        patterns = {
            "python": [
                r"^import\s+([\w.]+)",
                r"^from\s+([\w.]+)\s+import",
            ],
            "javascript": [
                r"^import\s+.*\s+from\s+['\"]([^'\"]+)['\"]",
                r"require\s*\(['\"]([^'\"]+)['\"]\)",
            ],
            "typescript": [
                r"^import\s+.*\s+from\s+['\"]([^'\"]+)['\"]",
            ],
            "java": [
                r"^import\s+([\w.]+);",
            ],
            "go": [
                # NOTE(review): this matches ANY quoted string, not just
                # import paths — consider anchoring to import blocks.
                r"['\"]([^'\"]+)['\"]",
            ],
        }
        for pattern in patterns.get(language, []):
            imports.extend(re.findall(pattern, content, re.MULTILINE))
        return list(set(imports))

    def _filter_relevant_imports(self, all_imports: List[str], chunk_content: str) -> List[str]:
        """Keep imports whose final component appears in the chunk; cap 20."""
        relevant: List[str] = []
        for imp in all_imports:
            module_name = imp.split('.')[-1]
            if re.search(rf'\b{re.escape(module_name)}\b', chunk_content):
                relevant.append(imp)
        return relevant[:20]

    def _extract_function_calls(self, content: str, language: str) -> List[str]:
        """Extract called identifiers, minus language keywords; cap 30."""
        matches = re.findall(r'\b(\w+)\s*\(', content)
        keywords = {
            "python": {"if", "for", "while", "with", "def", "class", "return", "except", "print", "assert", "lambda"},
            "javascript": {"if", "for", "while", "switch", "function", "return", "catch", "console", "async", "await"},
            "java": {"if", "for", "while", "switch", "return", "catch", "throw", "new"},
            "go": {"if", "for", "switch", "return", "func", "go", "defer"},
        }
        lang_keywords = keywords.get(language, set())
        return list({m for m in matches if m not in lang_keywords})[:30]

    def _extract_definitions(self, content: str, language: str) -> List[str]:
        """Extract identifiers defined within the chunk; cap 20."""
        definitions: List[str] = []
        patterns = {
            "python": [
                r"def\s+(\w+)\s*\(",
                r"class\s+(\w+)",
                r"(\w+)\s*=\s*",
            ],
            "javascript": [
                r"function\s+(\w+)",
                r"(?:const|let|var)\s+(\w+)",
                r"class\s+(\w+)",
            ],
        }
        for pattern in patterns.get(language, []):
            definitions.extend(re.findall(pattern, content))
        return list(set(definitions))[:20]

View File

@ -17,3 +17,51 @@ reportlab>=4.0.0
weasyprint>=66.0 weasyprint>=66.0
jinja2>=3.1.6 jinja2>=3.1.6
json-repair>=0.30.0 json-repair>=0.30.0
# ============ Agent 模块依赖 ============
# LangChain 核心
langchain>=0.1.0
langchain-community>=0.0.20
langchain-openai>=0.0.5
# LangGraph (状态图工作流)
langgraph>=0.0.40
# 向量数据库
chromadb>=0.4.22
# Token 计算
tiktoken>=0.5.2
# Docker 沙箱
docker>=7.0.0
# 异步文件操作
aiofiles>=23.2.1
# SSE 流
sse-starlette>=1.8.2
# ============ 代码解析 (高级库) ============
# Tree-sitter AST 解析
tree-sitter>=0.21.0
tree-sitter-languages>=1.10.0
# 通用代码解析
pygments>=2.17.0
# ============ 外部安全工具 (可选安装) ============
# 这些工具可以通过 pip 安装,或使用系统包管理器
# Python 安全扫描
bandit>=1.7.0
safety>=2.3.0
# 静态分析 (需要单独安装 semgrep CLI)
# pip install semgrep
# 依赖漏洞扫描
pip-audit>=2.6.0

62
docker/sandbox/Dockerfile Normal file
View File

@ -0,0 +1,62 @@
# DeepAudit Agent Sandbox
# Secure sandbox environment for vulnerability verification and PoC execution.
FROM python:3.11-slim-bookworm
LABEL maintainer="XCodeReviewer Team"
LABEL description="Secure sandbox environment for vulnerability verification"
# Install basic network/debugging tools used by verification scripts.
RUN apt-get update && apt-get install -y --no-install-recommends \
    curl \
    wget \
    netcat-openbsd \
    dnsutils \
    iputils-ping \
    ca-certificates \
    git \
    && rm -rf /var/lib/apt/lists/*
# Install Node.js (for JavaScript/TypeScript code execution).
RUN curl -fsSL https://deb.nodesource.com/setup_20.x | bash - \
    && apt-get install -y nodejs \
    && rm -rf /var/lib/apt/lists/*
# Install common Python libraries for security testing.
RUN pip install --no-cache-dir \
    requests \
    httpx \
    aiohttp \
    beautifulsoup4 \
    lxml \
    pycryptodome \
    paramiko \
    pyjwt \
    python-jose \
    sqlparse
# Create an unprivileged user (uid/gid 1000) for all sandboxed work.
RUN groupadd -g 1000 sandbox && \
    useradd -u 1000 -g sandbox -m -s /bin/bash sandbox
# Create working directories owned by the sandbox user.
RUN mkdir -p /workspace /tmp/sandbox && \
    chown -R sandbox:sandbox /workspace /tmp/sandbox
# Environment defaults for the sandbox user.
ENV HOME=/home/sandbox
ENV PATH=/home/sandbox/.local/bin:$PATH
ENV PYTHONDONTWRITEBYTECODE=1
ENV PYTHONUNBUFFERED=1
# Put /workspace on the Python import path.
# NOTE(review): the original comment said "restrict" the import path, but
# PYTHONPATH only prepends a search location — confirm the intent.
ENV PYTHONPATH=/workspace
# Drop privileges: everything below runs as the non-root user.
USER sandbox
WORKDIR /workspace
# Default command: interactive shell (overridden per verification run).
CMD ["/bin/bash"]

25
docker/sandbox/build.sh Normal file
View File

@ -0,0 +1,25 @@
#!/bin/bash
# Build the sandbox Docker image used for vulnerability verification.
set -e

# Resolve the directory containing this script so the build works
# regardless of the caller's current working directory.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
IMAGE_NAME="deepaudit-sandbox"
IMAGE_TAG="latest"

echo "Building sandbox image: ${IMAGE_NAME}:${IMAGE_TAG}"
docker build \
    -t "${IMAGE_NAME}:${IMAGE_TAG}" \
    -f "${SCRIPT_DIR}/Dockerfile" \
    "${SCRIPT_DIR}"
echo "Build complete: ${IMAGE_NAME}:${IMAGE_TAG}"

# Smoke-test the image: both runtimes must be present
# (each command fails the script via `set -e` if missing).
echo "Verifying image..."
docker run --rm "${IMAGE_NAME}:${IMAGE_TAG}" python3 --version
docker run --rm "${IMAGE_NAME}:${IMAGE_TAG}" node --version
echo "Sandbox image ready!"

266
docker/sandbox/seccomp.json Normal file
View File

@ -0,0 +1,266 @@
{
"defaultAction": "SCMP_ACT_ERRNO",
"defaultErrnoRet": 1,
"archMap": [
{
"architecture": "SCMP_ARCH_X86_64",
"subArchitectures": [
"SCMP_ARCH_X86",
"SCMP_ARCH_X32"
]
},
{
"architecture": "SCMP_ARCH_AARCH64",
"subArchitectures": [
"SCMP_ARCH_ARM"
]
}
],
"syscalls": [
{
"names": [
"accept",
"accept4",
"access",
"arch_prctl",
"bind",
"brk",
"capget",
"capset",
"chdir",
"chmod",
"chown",
"clock_getres",
"clock_gettime",
"clock_nanosleep",
"clone",
"clone3",
"close",
"connect",
"copy_file_range",
"dup",
"dup2",
"dup3",
"epoll_create",
"epoll_create1",
"epoll_ctl",
"epoll_pwait",
"epoll_pwait2",
"epoll_wait",
"eventfd",
"eventfd2",
"execve",
"execveat",
"exit",
"exit_group",
"faccessat",
"faccessat2",
"fadvise64",
"fallocate",
"fchdir",
"fchmod",
"fchmodat",
"fchown",
"fchownat",
"fcntl",
"fdatasync",
"fgetxattr",
"flistxattr",
"flock",
"fork",
"fstat",
"fstatfs",
"fsync",
"ftruncate",
"futex",
"getcwd",
"getdents",
"getdents64",
"getegid",
"geteuid",
"getgid",
"getgroups",
"getpeername",
"getpgid",
"getpgrp",
"getpid",
"getppid",
"getpriority",
"getrandom",
"getresgid",
"getresuid",
"getrlimit",
"getrusage",
"getsid",
"getsockname",
"getsockopt",
"gettid",
"gettimeofday",
"getuid",
"getxattr",
"inotify_add_watch",
"inotify_init",
"inotify_init1",
"inotify_rm_watch",
"ioctl",
"kill",
"lchown",
"lgetxattr",
"link",
"linkat",
"listen",
"listxattr",
"llistxattr",
"lseek",
"lstat",
"madvise",
"membarrier",
"memfd_create",
"mincore",
"mkdir",
"mkdirat",
"mknod",
"mknodat",
"mlock",
"mlock2",
"mlockall",
"mmap",
"mprotect",
"mremap",
"msync",
"munlock",
"munlockall",
"munmap",
"name_to_handle_at",
"nanosleep",
"newfstatat",
"open",
"openat",
"openat2",
"pause",
"pipe",
"pipe2",
"poll",
"ppoll",
"prctl",
"pread64",
"preadv",
"preadv2",
"prlimit64",
"pselect6",
"pwrite64",
"pwritev",
"pwritev2",
"read",
"readahead",
"readlink",
"readlinkat",
"readv",
"recv",
"recvfrom",
"recvmmsg",
"recvmsg",
"rename",
"renameat",
"renameat2",
"restart_syscall",
"rmdir",
"rseq",
"rt_sigaction",
"rt_sigpending",
"rt_sigprocmask",
"rt_sigqueueinfo",
"rt_sigreturn",
"rt_sigsuspend",
"rt_sigtimedwait",
"rt_tgsigqueueinfo",
"sched_getaffinity",
"sched_getattr",
"sched_getparam",
"sched_get_priority_max",
"sched_get_priority_min",
"sched_getscheduler",
"sched_rr_get_interval",
"sched_setaffinity",
"sched_setattr",
"sched_setparam",
"sched_setscheduler",
"sched_yield",
"seccomp",
"select",
"semctl",
"semget",
"semop",
"semtimedop",
"send",
"sendfile",
"sendmmsg",
"sendmsg",
"sendto",
"set_robust_list",
"set_tid_address",
"setgid",
"setgroups",
"setitimer",
"setpgid",
"setpriority",
"setregid",
"setresgid",
"setresuid",
"setreuid",
"setsid",
"setsockopt",
"setuid",
"shmat",
"shmctl",
"shmdt",
"shmget",
"shutdown",
"sigaltstack",
"signalfd",
"signalfd4",
"socket",
"socketpair",
"splice",
"stat",
"statfs",
"statx",
"symlink",
"symlinkat",
"sync",
"sync_file_range",
"syncfs",
"sysinfo",
"tee",
"tgkill",
"time",
"timer_create",
"timer_delete",
"timer_getoverrun",
"timer_gettime",
"timer_settime",
"timerfd_create",
"timerfd_gettime",
"timerfd_settime",
"times",
"tkill",
"truncate",
"umask",
"uname",
"unlink",
"unlinkat",
"utime",
"utimensat",
"utimes",
"vfork",
"wait4",
"waitid",
"waitpid",
"write",
"writev"
],
"action": "SCMP_ACT_ALLOW"
}
]
}

298
docs/AGENT_AUDIT.md Normal file
View File

@ -0,0 +1,298 @@
# DeepAudit Agent 审计模块
## 概述
Agent 审计模块是 DeepAudit 的高级安全审计功能,基于 **LangGraph** 状态图构建的混合 AI Agent 架构,实现自主代码安全分析和漏洞验证。
## LangGraph 工作流架构
```
┌─────────────────────────────────────────────────────────────────────┐
│ LangGraph 审计工作流 │
├─────────────────────────────────────────────────────────────────────┤
│ │
│ START │
│ │ │
│ ▼ │
│ ┌────────────────────────────────────────────────────────────────┐│
│ │ Recon Node (信息收集) ││
│ │ • 项目结构分析 • 技术栈识别 ││
│ │ • 入口点发现 • 依赖扫描 ││
│ │ ││
│ │ 使用工具: list_files, npm_audit, safety_scan, gitleaks_scan ││
│ └────────────────────────────┬───────────────────────────────────┘│
│ │ │
│ ▼ │
│ ┌────────────────────────────────────────────────────────────────┐│
│ │ Analysis Node (漏洞分析) ││
│ │ • Semgrep 静态分析 • RAG 语义搜索 ││
│ │ • 模式匹配 • LLM 深度分析 ││
│ │ • 数据流追踪 ││
│ │ ◄─────┐ ││
│ │ 使用工具: semgrep_scan, bandit_scan, rag_query, │ ││
│ │ code_analysis, pattern_match │ ││
│ └────────────────────────────┬──────────────────────────┘───────┘│
│ │ │ │
│ ▼ │ │
│ ┌────────────────────────────────────────────────────────────────┐│
│ │ Verification Node (漏洞验证) ││
│ │ • LLM 漏洞验证 • 沙箱测试 ││
│ │ • PoC 生成 • 误报过滤 ││
│ │ ────────┘ ││
│ │ 使用工具: vulnerability_validation, sandbox_exec, ││
│ │ verify_vulnerability ││
│ └────────────────────────────┬───────────────────────────────────┘│
│ │ │
│ ▼ │
│ ┌────────────────────────────────────────────────────────────────┐│
│ │ Report Node (报告生成) ││
│ │ • 漏洞汇总 • 安全评分 ││
│ │ • 修复建议 • 统计分析 ││
│ └────────────────────────────┬───────────────────────────────────┘│
│ │ │
│ ▼ │
│ END │
│ │
└────────────────────────────────────────────────────────────────────┘
状态流转:
• Recon → Analysis: 收集到入口点后进入分析
• Analysis → Analysis: 发现较多问题时继续迭代
• Analysis → Verification: 有发现时进入验证
• Verification → Analysis: 误报率高时回溯分析
• Verification → Report: 验证完成后生成报告
```
## 核心特性
### 1. LangGraph 状态图
- **声明式工作流**: 使用图结构定义 Agent 协作流程
- **状态自动合并**: `Annotated[List, operator.add]` 实现发现累加
- **条件路由**: 基于状态动态决定下一步
- **检查点恢复**: 支持任务中断后继续
### 2. Agent 工具集
#### 内置工具
| 工具 | 功能 | 节点 |
|------|------|------|
| `list_files` | 目录浏览 | Recon |
| `read_file` | 文件读取 | All |
| `search_code` | 代码搜索 | Analysis |
| `rag_query` | 语义检索 | Analysis |
| `security_search` | 安全代码搜索 | Analysis |
| `function_context` | 函数上下文 | Analysis |
| `pattern_match` | 模式匹配 | Analysis |
| `code_analysis` | LLM 分析 | Analysis |
| `dataflow_analysis` | 数据流追踪 | Analysis |
| `vulnerability_validation` | 漏洞验证 | Verification |
| `sandbox_exec` | 沙箱执行 | Verification |
| `verify_vulnerability` | 自动验证 | Verification |
#### 外部安全工具
| 工具 | 功能 | 适用场景 |
|------|------|----------|
| `semgrep_scan` | Semgrep 静态分析 | 多语言快速扫描 |
| `bandit_scan` | Bandit Python 扫描 | Python 安全分析 |
| `gitleaks_scan` | Gitleaks 密钥检测 | 密钥泄露检测 |
| `trufflehog_scan` | TruffleHog 扫描 | 深度密钥扫描 |
| `npm_audit` | npm 依赖审计 | Node.js 依赖漏洞 |
| `safety_scan` | Safety Python 审计 | Python 依赖漏洞 |
| `osv_scan` | OSV 漏洞扫描 | 多语言依赖漏洞 |
### 3. RAG 系统
- **代码分块**: 基于 Tree-sitter AST 的智能分块
- **向量存储**: ChromaDB 持久化
- **多语言支持**: Python, JavaScript, TypeScript, Java, Go, PHP, Rust 等
- **嵌入模型**: 独立配置,支持 OpenAI、Ollama、Cohere、HuggingFace
### 4. 安全沙箱
- **Docker 隔离**: 安全容器执行
- **资源限制**: 内存、CPU 限制
- **网络隔离**: 可配置网络访问
- **seccomp 策略**: 系统调用白名单
## 配置
### 环境变量
```bash
# LLM 配置
DEFAULT_LLM_MODEL=gpt-4-turbo-preview
LLM_API_KEY=your-api-key
LLM_BASE_URL=https://api.openai.com/v1
# 嵌入模型配置(独立于 LLM)
EMBEDDING_PROVIDER=openai
EMBEDDING_MODEL=text-embedding-3-small
# 向量数据库
VECTOR_DB_PATH=./data/vectordb
# 沙箱配置
SANDBOX_IMAGE=deepaudit-sandbox:latest
SANDBOX_MEMORY_LIMIT=512m
SANDBOX_CPU_LIMIT=1.0
SANDBOX_NETWORK_DISABLED=true
```
### Agent 任务配置
```json
{
"target_vulnerabilities": [
"sql_injection",
"xss",
"command_injection",
"path_traversal",
"ssrf"
],
"verification_level": "sandbox",
"exclude_patterns": ["node_modules", "__pycache__", ".git"],
"max_iterations": 3,
"timeout_seconds": 1800
}
```
## 部署
### 1. 安装依赖
```bash
cd backend
pip install -r requirements.txt
# 可选:安装外部工具
pip install semgrep bandit safety
brew install gitleaks trufflehog osv-scanner # macOS
```
### 2. 构建沙箱镜像
```bash
cd docker/sandbox
./build.sh
```
### 3. 数据库迁移
```bash
alembic upgrade head
```
### 4. 启动服务
```bash
uvicorn app.main:app --host 0.0.0.0 --port 8000
```
## API 接口
### 创建任务
```http
POST /api/v1/agent-tasks/
Content-Type: application/json
{
"project_id": "xxx",
"name": "安全审计",
"target_vulnerabilities": ["sql_injection", "xss"],
"verification_level": "sandbox",
"max_iterations": 3
}
```
### 事件流
```http
GET /api/v1/agent-tasks/{task_id}/events
Accept: text/event-stream
```
### 获取发现
```http
GET /api/v1/agent-tasks/{task_id}/findings?verified_only=true
```
### 任务摘要
```http
GET /api/v1/agent-tasks/{task_id}/summary
```
## 支持的漏洞类型
| 类型 | 说明 |
|------|------|
| `sql_injection` | SQL 注入 |
| `xss` | 跨站脚本 |
| `command_injection` | 命令注入 |
| `path_traversal` | 路径遍历 |
| `ssrf` | 服务端请求伪造 |
| `xxe` | XML 外部实体 |
| `insecure_deserialization` | 不安全反序列化 |
| `hardcoded_secret` | 硬编码密钥 |
| `weak_crypto` | 弱加密 |
| `authentication_bypass` | 认证绕过 |
| `authorization_bypass` | 授权绕过 |
| `idor` | 不安全直接对象引用 |
## 目录结构
```
backend/app/services/agent/
├── __init__.py # 模块导出
├── event_manager.py # 事件管理
├── agents/ # Agent 实现
│ ├── __init__.py
│ ├── base.py # Agent 基类
│ ├── recon.py # 信息收集 Agent
│ ├── analysis.py # 漏洞分析 Agent
│ ├── verification.py # 漏洞验证 Agent
│ └── orchestrator.py # 编排 Agent
├── graph/ # LangGraph 工作流
│ ├── __init__.py
│ ├── audit_graph.py # 状态定义和图构建
│ ├── nodes.py # 节点实现
│ └── runner.py # 执行器
├── tools/ # Agent 工具
│ ├── __init__.py
│ ├── base.py # 工具基类
│ ├── rag_tool.py # RAG 工具
│ ├── pattern_tool.py # 模式匹配工具
│ ├── code_analysis_tool.py
│ ├── file_tool.py # 文件操作
│ ├── sandbox_tool.py # 沙箱工具
│ └── external_tools.py # 外部安全工具
└── prompts/ # 系统提示词
├── __init__.py
└── system_prompts.py
```
## 故障排除
### 沙箱镜像检查
```bash
docker images | grep deepaudit-sandbox
```
### 日志查看
```bash
tail -f logs/agent.log
```
### 常见问题
1. **RAG 初始化失败**: 检查嵌入模型配置和 API Key
2. **沙箱启动失败**: 确保 Docker 正常运行
3. **外部工具不可用**: 检查 semgrep/bandit 等是否已安装

View File

@ -54,3 +54,4 @@ EXPOSE 3000
ENTRYPOINT ["/docker-entrypoint.sh"] ENTRYPOINT ["/docker-entrypoint.sh"]
CMD ["serve", "-s", "dist", "-l", "3000"] CMD ["serve", "-s", "dist", "-l", "3000"]

View File

@ -17,3 +17,4 @@ export const ProtectedRoute = () => {
}; };

View File

@ -5,6 +5,7 @@ import RecycleBin from "@/pages/RecycleBin";
import InstantAnalysis from "@/pages/InstantAnalysis"; import InstantAnalysis from "@/pages/InstantAnalysis";
import AuditTasks from "@/pages/AuditTasks"; import AuditTasks from "@/pages/AuditTasks";
import TaskDetail from "@/pages/TaskDetail"; import TaskDetail from "@/pages/TaskDetail";
import AgentAudit from "@/pages/AgentAudit";
import AdminDashboard from "@/pages/AdminDashboard"; import AdminDashboard from "@/pages/AdminDashboard";
import LogsPage from "@/pages/LogsPage"; import LogsPage from "@/pages/LogsPage";
import Account from "@/pages/Account"; import Account from "@/pages/Account";
@ -56,6 +57,12 @@ const routes: RouteConfig[] = [
element: <TaskDetail />, element: <TaskDetail />,
visible: false, visible: false,
}, },
{
name: "Agent审计",
path: "/agent-audit/:taskId",
element: <AgentAudit />,
visible: false,
},
{ {
name: "审计规则", name: "审计规则",
path: "/audit-rules", path: "/audit-rules",

View File

@ -0,0 +1,158 @@
/**
 * AgentModeSelector — audit mode picker.
 *
 * Renders two radio-card options ("fast" single-pass audit vs. the
 * agent-based deep audit) plus an explanatory panel that appears when
 * the agent mode is selected.
 *
 * NOTE(review): several JSX text labels in this revision (option titles
 * and bullet copy) appear to have been stripped to empty or truncated
 * strings — confirm against the intended copy before shipping.
 */
import { Bot, Zap, CheckCircle2, Clock, Shield, Code } from "lucide-react";
import { cn } from "@/shared/utils/utils";

/** "fast" = quick single-pass LLM scan; "agent" = autonomous deep audit. */
export type AuditMode = "fast" | "agent";

interface AgentModeSelectorProps {
  /** Currently selected audit mode. */
  value: AuditMode;
  /** Invoked with the newly chosen mode. */
  onChange: (mode: AuditMode) => void;
  /** Disables both options (e.g. while a task is running). */
  disabled?: boolean;
}

export default function AgentModeSelector({
  value,
  onChange,
  disabled = false,
}: AgentModeSelectorProps) {
  return (
    <div className="space-y-3">
      {/* Section header */}
      <div className="flex items-center gap-2 mb-2">
        <Shield className="w-4 h-4 text-indigo-700" />
        <span className="font-mono text-sm font-bold text-indigo-900 uppercase">
        </span>
      </div>
      <div className="grid grid-cols-2 gap-3">
        {/* Fast audit mode card */}
        <label
          className={cn(
            "relative flex flex-col p-4 border-2 cursor-pointer transition-all rounded-none",
            value === "fast"
              ? "border-amber-500 bg-amber-50"
              : "border-gray-300 hover:border-gray-400 bg-white",
            disabled && "opacity-50 cursor-not-allowed"
          )}
        >
          {/* Visually hidden native radio keeps the control accessible. */}
          <input
            type="radio"
            name="auditMode"
            value="fast"
            checked={value === "fast"}
            onChange={() => onChange("fast")}
            disabled={disabled}
            className="sr-only"
          />
          <div className="flex items-center gap-2 mb-2">
            <div className="p-1.5 bg-amber-100 border border-amber-300">
              <Zap className="w-4 h-4 text-amber-600" />
            </div>
            <span className="font-bold text-sm"></span>
            {value === "fast" && (
              <CheckCircle2 className="w-4 h-4 text-amber-600 ml-auto" />
            )}
          </div>
          <ul className="text-xs text-gray-600 space-y-1 mb-3">
            <li className="flex items-center gap-1">
              <Clock className="w-3 h-3" />
            </li>
            <li className="flex items-center gap-1">
              <Code className="w-3 h-3" />
              LLM
            </li>
            <li className="flex items-center gap-1 text-gray-400">
              <Shield className="w-3 h-3" />
            </li>
          </ul>
          <div className="mt-auto pt-2 border-t border-gray-200">
            <span className="text-[10px] uppercase tracking-wider text-gray-500 font-bold">
              适合: CI/CD
            </span>
          </div>
        </label>
        {/* Agent audit mode card */}
        <label
          className={cn(
            "relative flex flex-col p-4 border-2 cursor-pointer transition-all rounded-none",
            value === "agent"
              ? "border-purple-500 bg-purple-50"
              : "border-gray-300 hover:border-gray-400 bg-white",
            disabled && "opacity-50 cursor-not-allowed"
          )}
        >
          <input
            type="radio"
            name="auditMode"
            value="agent"
            checked={value === "agent"}
            onChange={() => onChange("agent")}
            disabled={disabled}
            className="sr-only"
          />
          {/* "Recommended" corner badge */}
          <div className="absolute -top-2 -right-2 px-2 py-0.5 bg-purple-600 text-white text-[10px] font-bold uppercase">
          </div>
          <div className="flex items-center gap-2 mb-2">
            <div className="p-1.5 bg-purple-100 border border-purple-300">
              <Bot className="w-4 h-4 text-purple-600" />
            </div>
            <span className="font-bold text-sm">Agent </span>
            {value === "agent" && (
              <CheckCircle2 className="w-4 h-4 text-purple-600 ml-auto" />
            )}
          </div>
          <ul className="text-xs text-gray-600 space-y-1 mb-3">
            <li className="flex items-center gap-1">
              <Bot className="w-3 h-3" />
              AI Agent
            </li>
            <li className="flex items-center gap-1">
              <Code className="w-3 h-3" />
              + RAG
            </li>
            <li className="flex items-center gap-1 text-purple-600 font-medium">
              <Shield className="w-3 h-3" />
            </li>
          </ul>
          <div className="mt-auto pt-2 border-t border-gray-200">
            <span className="text-[10px] uppercase tracking-wider text-gray-500 font-bold">
              适合: 发版前审计
            </span>
          </div>
        </label>
      </div>
      {/* Explanatory panel shown only while the agent mode is selected */}
      {value === "agent" && (
        <div className="p-3 bg-purple-50 border border-purple-200 text-xs text-purple-800 rounded-none">
          <p className="font-bold mb-1">🤖 Agent </p>
          <ul className="list-disc list-inside space-y-0.5 text-purple-700">
            <li>AI Agent </li>
            <li>使 RAG </li>
            <li> Docker </li>
            <li> PoC</li>
            <li></li>
          </ul>
        </div>
      )}
    </div>
  );
}

View File

@ -0,0 +1,449 @@
/**
*
* LLM Agent RAG
*/
import { useState, useEffect } from "react";
import {
Card,
CardContent,
CardDescription,
CardHeader,
CardTitle,
} from "@/components/ui/card";
import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
import { Label } from "@/components/ui/label";
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from "@/components/ui/select";
import { Badge } from "@/components/ui/badge";
import { Separator } from "@/components/ui/separator";
import {
Brain,
Cpu,
Check,
X,
Loader2,
RefreshCw,
Server,
Key,
Zap,
Info,
} from "lucide-react";
import { toast } from "sonner";
import { apiClient } from "@/shared/api/serverClient";
/** Embedding provider metadata as returned by `GET /embedding/providers`. */
interface EmbeddingProvider {
  /** Stable provider identifier used in the config payloads. */
  id: string;
  /** Human-readable provider name shown in the selector. */
  name: string;
  /** Short description shown in the UI. */
  description: string;
  /** Model names offered by this provider. */
  models: string[];
  /** Whether an API key must be supplied before this provider can be used. */
  requires_api_key: boolean;
  /** Model pre-selected when the provider is chosen. */
  default_model: string;
}
/** Currently active embedding configuration (`GET /embedding/config`). */
interface EmbeddingConfig {
  /** Selected provider id (see EmbeddingProvider.id). */
  provider: string;
  /** Selected model name. */
  model: string;
  /** Custom endpoint base URL, or null when the provider default is used. */
  base_url: string | null;
  /** Dimensionality of the produced embedding vectors. */
  dimensions: number;
  /** Number of texts embedded per request batch. */
  batch_size: number;
}
/** Result of an embedding connectivity test (`POST /embedding/test`). */
interface TestResult {
  /** Whether the test request succeeded. */
  success: boolean;
  /** Human-readable status or error message. */
  message: string;
  /** Vector dimensionality reported by the test, when available. */
  dimensions?: number;
  /** Truncated sample of the returned embedding, when available. */
  sample_embedding?: number[];
  /** Round-trip latency of the test call, in milliseconds. */
  latency_ms?: number;
}
/**
 * Embedding model configuration panel.
 *
 * Fetches the available embedding providers and the saved configuration from
 * the backend, lets the user edit provider / model / API key / custom endpoint
 * / batch size, and offers a connectivity test (POST /embedding/test) and a
 * save action (PUT /embedding/config).
 */
export default function EmbeddingConfigPanel() {
  const [providers, setProviders] = useState<EmbeddingProvider[]>([]);
  const [currentConfig, setCurrentConfig] = useState<EmbeddingConfig | null>(null);
  const [loading, setLoading] = useState(true);
  const [saving, setSaving] = useState(false);
  const [testing, setTesting] = useState(false);
  const [testResult, setTestResult] = useState<TestResult | null>(null);

  // Form state
  const [selectedProvider, setSelectedProvider] = useState("");
  const [selectedModel, setSelectedModel] = useState("");
  const [apiKey, setApiKey] = useState("");
  const [baseUrl, setBaseUrl] = useState("");
  const [batchSize, setBatchSize] = useState(100);

  // Initial data load
  useEffect(() => {
    loadData();
  }, []);

  // When the provider changes, fall back to that provider's default model
  useEffect(() => {
    if (selectedProvider) {
      const provider = providers.find((p) => p.id === selectedProvider);
      if (provider) {
        setSelectedModel(provider.default_model);
      }
    }
  }, [selectedProvider, providers]);

  // Load providers and the current config in parallel, then seed the form
  const loadData = async () => {
    try {
      setLoading(true);
      const [providersRes, configRes] = await Promise.all([
        apiClient.get("/embedding/providers"),
        apiClient.get("/embedding/config"),
      ]);
      setProviders(providersRes.data);
      setCurrentConfig(configRes.data);
      // Seed form defaults from the saved config
      if (configRes.data) {
        setSelectedProvider(configRes.data.provider);
        setSelectedModel(configRes.data.model);
        setBaseUrl(configRes.data.base_url || "");
        setBatchSize(configRes.data.batch_size);
      }
    } catch (error) {
      toast.error("加载配置失败");
    } finally {
      setLoading(false);
    }
  };

  // Validate the form, persist the configuration, then reload it
  const handleSave = async () => {
    if (!selectedProvider || !selectedModel) {
      toast.error("请选择提供商和模型");
      return;
    }
    const provider = providers.find((p) => p.id === selectedProvider);
    if (provider?.requires_api_key && !apiKey) {
      toast.error(`${provider.name} 需要 API Key`);
      return;
    }
    try {
      setSaving(true);
      await apiClient.put("/embedding/config", {
        provider: selectedProvider,
        model: selectedModel,
        // Omit empty strings so the backend keeps/derives its own values
        api_key: apiKey || undefined,
        base_url: baseUrl || undefined,
        batch_size: batchSize,
      });
      toast.success("配置已保存");
      await loadData();
    } catch (error: any) {
      toast.error(error.response?.data?.detail || "保存失败");
    } finally {
      setSaving(false);
    }
  };

  // Run a test embedding with the current (possibly unsaved) form values
  const handleTest = async () => {
    if (!selectedProvider || !selectedModel) {
      toast.error("请选择提供商和模型");
      return;
    }
    try {
      setTesting(true);
      setTestResult(null);
      const response = await apiClient.post("/embedding/test", {
        provider: selectedProvider,
        model: selectedModel,
        api_key: apiKey || undefined,
        base_url: baseUrl || undefined,
      });
      setTestResult(response.data);
      if (response.data.success) {
        toast.success("测试成功");
      } else {
        toast.error("测试失败");
      }
    } catch (error: any) {
      setTestResult({
        success: false,
        message: error.response?.data?.detail || "测试失败",
      });
      toast.error("测试失败");
    } finally {
      setTesting(false);
    }
  };

  const selectedProviderInfo = providers.find((p) => p.id === selectedProvider);

  if (loading) {
    return (
      <div className="flex items-center justify-center p-8">
        <Loader2 className="w-6 h-6 animate-spin" />
      </div>
    );
  }

  return (
    <Card className="border-2 border-black rounded-none shadow-[4px_4px_0px_0px_rgba(0,0,0,1)]">
      <CardHeader className="border-b-2 border-black bg-purple-50">
        <div className="flex items-center gap-3">
          <div className="p-2 bg-purple-100 border-2 border-purple-300">
            <Brain className="w-5 h-5 text-purple-600" />
          </div>
          <div>
            <CardTitle className="font-mono text-lg"></CardTitle>
            <CardDescription>
              Agent RAG LLM
            </CardDescription>
          </div>
        </div>
      </CardHeader>
      <CardContent className="p-6 space-y-6">
        {/* Current configuration summary */}
        {currentConfig && (
          <div className="p-4 bg-gray-50 border-2 border-gray-200 space-y-2">
            <div className="flex items-center gap-2 text-sm font-mono font-bold">
              <Server className="w-4 h-4" />
            </div>
            <div className="grid grid-cols-2 gap-4 text-sm">
              <div>
                <span className="text-gray-500">:</span>{" "}
                <Badge variant="outline" className="ml-1">
                  {currentConfig.provider}
                </Badge>
              </div>
              <div>
                <span className="text-gray-500">:</span>{" "}
                <span className="font-mono">{currentConfig.model}</span>
              </div>
              <div>
                <span className="text-gray-500">:</span>{" "}
                <span className="font-mono">{currentConfig.dimensions}</span>
              </div>
              <div>
                <span className="text-gray-500">:</span>{" "}
                <span className="font-mono">{currentConfig.batch_size}</span>
              </div>
            </div>
          </div>
        )}
        <Separator />
        {/* Provider selection */}
        <div className="space-y-2">
          <Label className="font-mono font-bold"></Label>
          <Select value={selectedProvider} onValueChange={setSelectedProvider}>
            <SelectTrigger className="border-2 border-black rounded-none">
              <SelectValue placeholder="选择提供商" />
            </SelectTrigger>
            <SelectContent className="border-2 border-black rounded-none">
              {providers.map((provider) => (
                <SelectItem key={provider.id} value={provider.id}>
                  <div className="flex items-center gap-2">
                    <span>{provider.name}</span>
                    {provider.requires_api_key ? (
                      <Key className="w-3 h-3 text-amber-500" />
                    ) : (
                      <Cpu className="w-3 h-3 text-green-500" />
                    )}
                  </div>
                </SelectItem>
              ))}
            </SelectContent>
          </Select>
          {selectedProviderInfo && (
            <p className="text-xs text-gray-500 flex items-center gap-1">
              <Info className="w-3 h-3" />
              {selectedProviderInfo.description}
            </p>
          )}
        </div>
        {/* Model selection (only once a provider is chosen) */}
        {selectedProviderInfo && (
          <div className="space-y-2">
            <Label className="font-mono font-bold"></Label>
            <Select value={selectedModel} onValueChange={setSelectedModel}>
              <SelectTrigger className="border-2 border-black rounded-none">
                <SelectValue placeholder="选择模型" />
              </SelectTrigger>
              <SelectContent className="border-2 border-black rounded-none">
                {selectedProviderInfo.models.map((model) => (
                  <SelectItem key={model} value={model}>
                    <span className="font-mono text-sm">{model}</span>
                  </SelectItem>
                ))}
              </SelectContent>
            </Select>
          </div>
        )}
        {/* API Key (only for providers that require one) */}
        {selectedProviderInfo?.requires_api_key && (
          <div className="space-y-2">
            <Label className="font-mono font-bold">
              API Key
              <span className="text-red-500 ml-1">*</span>
            </Label>
            <Input
              type="password"
              value={apiKey}
              onChange={(e) => setApiKey(e.target.value)}
              placeholder="输入 API Key"
              className="border-2 border-black rounded-none font-mono"
            />
            <p className="text-xs text-gray-500">
              API Key
            </p>
          </div>
        )}
        {/* Custom endpoint; placeholder shows each provider's default URL */}
        <div className="space-y-2">
          <Label className="font-mono font-bold">
            API <span className="text-gray-400">()</span>
          </Label>
          <Input
            type="url"
            value={baseUrl}
            onChange={(e) => setBaseUrl(e.target.value)}
            placeholder={
              selectedProvider === "ollama"
                ? "http://localhost:11434"
                : selectedProvider === "huggingface"
                ? "https://router.huggingface.co"
                : selectedProvider === "cohere"
                ? "https://api.cohere.com/v2"
                : selectedProvider === "jina"
                ? "https://api.jina.ai/v1"
                : "https://api.openai.com/v1"
            }
            className="border-2 border-black rounded-none font-mono"
          />
          <p className="text-xs text-gray-500">
            API
          </p>
        </div>
        {/* Batch size */}
        <div className="space-y-2">
          <Label className="font-mono font-bold"></Label>
          <Input
            type="number"
            value={batchSize}
            onChange={(e) => setBatchSize(parseInt(e.target.value) || 100)}
            min={1}
            max={500}
            className="border-2 border-black rounded-none font-mono w-32"
          />
          <p className="text-xs text-gray-500">
            50-100
          </p>
        </div>
        {/* Test result */}
        {testResult && (
          <div
            className={`p-4 border-2 ${
              testResult.success
                ? "border-green-500 bg-green-50"
                : "border-red-500 bg-red-50"
            }`}
          >
            <div className="flex items-center gap-2 mb-2">
              {testResult.success ? (
                <Check className="w-5 h-5 text-green-600" />
              ) : (
                <X className="w-5 h-5 text-red-600" />
              )}
              <span
                className={`font-bold ${
                  testResult.success ? "text-green-700" : "text-red-700"
                }`}
              >
                {testResult.success ? "测试成功" : "测试失败"}
              </span>
            </div>
            <p className="text-sm">{testResult.message}</p>
            {testResult.success && (
              <div className="mt-2 text-xs text-gray-600 space-y-1">
                <div>: {testResult.dimensions}</div>
                <div>: {testResult.latency_ms}ms</div>
                {testResult.sample_embedding && (
                  <div>
                    : [{testResult.sample_embedding.map((v) => v.toFixed(4)).join(", ")}...]
                  </div>
                )}
              </div>
            )}
          </div>
        )}
        {/* Action buttons: test / save / reload */}
        <div className="flex items-center gap-3 pt-4">
          <Button
            onClick={handleTest}
            disabled={testing || !selectedProvider || !selectedModel}
            variant="outline"
            className="border-2 border-black rounded-none hover:bg-gray-100"
          >
            {testing ? (
              <Loader2 className="w-4 h-4 mr-2 animate-spin" />
            ) : (
              <Zap className="w-4 h-4 mr-2" />
            )}
          </Button>
          <Button
            onClick={handleSave}
            disabled={saving || !selectedProvider || !selectedModel}
            className="bg-purple-600 hover:bg-purple-700 border-2 border-black rounded-none"
          >
            {saving ? (
              <Loader2 className="w-4 h-4 mr-2 animate-spin" />
            ) : (
              <Check className="w-4 h-4 mr-2" />
            )}
          </Button>
          <Button
            onClick={loadData}
            variant="ghost"
            className="ml-auto"
          >
            <RefreshCw className="w-4 h-4" />
          </Button>
        </div>
        {/* Help notes */}
        <div className="p-4 bg-blue-50 border-l-4 border-blue-500 text-sm">
          <p className="font-bold mb-1">💡 </p>
          <ul className="list-disc list-inside text-gray-600 space-y-1">
            <li> Agent (RAG)</li>
            <li>使 LLM </li>
            <li>使 OpenAI text-embedding-3-small Ollama</li>
            <li></li>
          </ul>
        </div>
      </CardContent>
    </Card>
  );
}

View File

@ -1,4 +1,5 @@
import { useState, useEffect, useMemo, useRef } from "react"; import { useState, useEffect, useMemo, useRef } from "react";
import { useNavigate } from "react-router-dom";
import { import {
Dialog, Dialog,
DialogContent, DialogContent,
@ -35,16 +36,19 @@ import {
Shield, Shield,
Loader2, Loader2,
Zap, Zap,
Bot,
} from "lucide-react"; } from "lucide-react";
import { toast } from "sonner"; import { toast } from "sonner";
import { api } from "@/shared/config/database"; import { api } from "@/shared/config/database";
import { getRuleSets, type AuditRuleSet } from "@/shared/api/rules"; import { getRuleSets, type AuditRuleSet } from "@/shared/api/rules";
import { getPromptTemplates, type PromptTemplate } from "@/shared/api/prompts"; import { getPromptTemplates, type PromptTemplate } from "@/shared/api/prompts";
import { createAgentTask } from "@/shared/api/agentTasks";
import { useProjects } from "./hooks/useTaskForm"; import { useProjects } from "./hooks/useTaskForm";
import { useZipFile, formatFileSize } from "./hooks/useZipFile"; import { useZipFile, formatFileSize } from "./hooks/useZipFile";
import TerminalProgressDialog from "./TerminalProgressDialog"; import TerminalProgressDialog from "./TerminalProgressDialog";
import FileSelectionDialog from "./FileSelectionDialog"; import FileSelectionDialog from "./FileSelectionDialog";
import AgentModeSelector, { type AuditMode } from "@/components/agent/AgentModeSelector";
import { runRepositoryAudit } from "@/features/projects/services/repoScan"; import { runRepositoryAudit } from "@/features/projects/services/repoScan";
import { import {
@ -76,6 +80,7 @@ export default function CreateTaskDialog({
onTaskCreated, onTaskCreated,
preselectedProjectId, preselectedProjectId,
}: CreateTaskDialogProps) { }: CreateTaskDialogProps) {
const navigate = useNavigate();
const [selectedProjectId, setSelectedProjectId] = useState<string>(""); const [selectedProjectId, setSelectedProjectId] = useState<string>("");
const [searchTerm, setSearchTerm] = useState(""); const [searchTerm, setSearchTerm] = useState("");
const [branch, setBranch] = useState("main"); const [branch, setBranch] = useState("main");
@ -90,6 +95,9 @@ export default function CreateTaskDialog({
const [showTerminal, setShowTerminal] = useState(false); const [showTerminal, setShowTerminal] = useState(false);
const [currentTaskId, setCurrentTaskId] = useState<string | null>(null); const [currentTaskId, setCurrentTaskId] = useState<string | null>(null);
// 审计模式
const [auditMode, setAuditMode] = useState<AuditMode>("agent");
// 规则集和提示词模板 // 规则集和提示词模板
const [ruleSets, setRuleSets] = useState<AuditRuleSet[]>([]); const [ruleSets, setRuleSets] = useState<AuditRuleSet[]>([]);
const [promptTemplates, setPromptTemplates] = useState<PromptTemplate[]>([]); const [promptTemplates, setPromptTemplates] = useState<PromptTemplate[]>([]);
@ -205,6 +213,31 @@ export default function CreateTaskDialog({
setCreating(true); setCreating(true);
let taskId: string; let taskId: string;
// Agent 审计模式
if (auditMode === "agent") {
const agentTask = await createAgentTask({
project_id: selectedProject.id,
name: `Agent审计-${selectedProject.name}`,
branch_name: isRepositoryProject(selectedProject) ? branch : undefined,
exclude_patterns: excludePatterns,
target_files: selectedFiles,
verification_level: "sandbox",
});
onOpenChange(false);
onTaskCreated();
toast.success("Agent 审计任务已创建");
// 导航到 Agent 审计页面
navigate(`/agent-audit/${agentTask.id}`);
setSelectedProjectId("");
setSelectedFiles(undefined);
setExcludePatterns(DEFAULT_EXCLUDES);
return;
}
// 快速审计模式(原有逻辑)
if (isZipProject(selectedProject)) { if (isZipProject(selectedProject)) {
if (zipState.useStoredZip && zipState.storedZipInfo?.has_file) { if (zipState.useStoredZip && zipState.storedZipInfo?.has_file) {
taskId = await scanStoredZipFile({ taskId = await scanStoredZipFile({
@ -339,6 +372,15 @@ export default function CreateTaskDialog({
</ScrollArea> </ScrollArea>
</div> </div>
{/* 审计模式选择 */}
{selectedProject && (
<AgentModeSelector
value={auditMode}
onChange={setAuditMode}
disabled={creating}
/>
)}
{/* 配置区域 */} {/* 配置区域 */}
{selectedProject && ( {selectedProject && (
<div className="space-y-4"> <div className="space-y-4">
@ -589,10 +631,15 @@ export default function CreateTaskDialog({
<div className="animate-spin h-4 w-4 border-2 border-white border-t-transparent mr-2" /> <div className="animate-spin h-4 w-4 border-2 border-white border-t-transparent mr-2" />
... ...
</> </>
) : auditMode === "agent" ? (
<>
<Bot className="w-4 h-4 mr-2" />
Agent
</>
) : ( ) : (
<> <>
<Play className="w-4 h-4 mr-2" /> <Zap className="w-4 h-4 mr-2" />
</> </>
)} )}
</Button> </Button>

View File

@ -6,10 +6,11 @@ import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@
import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs"; import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
import { import {
Settings, Save, RotateCcw, Eye, EyeOff, CheckCircle2, AlertCircle, Settings, Save, RotateCcw, Eye, EyeOff, CheckCircle2, AlertCircle,
Info, Zap, Globe, PlayCircle, Loader2 Info, Zap, Globe, PlayCircle, Loader2, Brain
} from "lucide-react"; } from "lucide-react";
import { toast } from "sonner"; import { toast } from "sonner";
import { api } from "@/shared/api/database"; import { api } from "@/shared/api/database";
import EmbeddingConfig from "@/components/agent/EmbeddingConfig";
// LLM 提供商配置 - 2025年最新 // LLM 提供商配置 - 2025年最新
const LLM_PROVIDERS = [ const LLM_PROVIDERS = [
@ -246,10 +247,13 @@ export function SystemConfig() {
</div> </div>
<Tabs defaultValue="llm" className="w-full"> <Tabs defaultValue="llm" className="w-full">
<TabsList className="grid w-full grid-cols-3 bg-transparent border-2 border-black p-0 h-auto gap-0 mb-6"> <TabsList className="grid w-full grid-cols-4 bg-transparent border-2 border-black p-0 h-auto gap-0 mb-6">
<TabsTrigger value="llm" className="rounded-none border-r-2 border-black data-[state=active]:bg-black data-[state=active]:text-white font-mono font-bold uppercase h-10 text-xs"> <TabsTrigger value="llm" className="rounded-none border-r-2 border-black data-[state=active]:bg-black data-[state=active]:text-white font-mono font-bold uppercase h-10 text-xs">
<Zap className="w-3 h-3 mr-2" /> LLM <Zap className="w-3 h-3 mr-2" /> LLM
</TabsTrigger> </TabsTrigger>
<TabsTrigger value="embedding" className="rounded-none border-r-2 border-black data-[state=active]:bg-black data-[state=active]:text-white font-mono font-bold uppercase h-10 text-xs">
<Brain className="w-3 h-3 mr-2" />
</TabsTrigger>
<TabsTrigger value="analysis" className="rounded-none border-r-2 border-black data-[state=active]:bg-black data-[state=active]:text-white font-mono font-bold uppercase h-10 text-xs"> <TabsTrigger value="analysis" className="rounded-none border-r-2 border-black data-[state=active]:bg-black data-[state=active]:text-white font-mono font-bold uppercase h-10 text-xs">
<Settings className="w-3 h-3 mr-2" /> <Settings className="w-3 h-3 mr-2" />
</TabsTrigger> </TabsTrigger>
@ -388,6 +392,11 @@ export function SystemConfig() {
</div> </div>
</TabsContent> </TabsContent>
{/* 嵌入模型配置 */}
<TabsContent value="embedding" className="space-y-6">
<EmbeddingConfig />
</TabsContent>
{/* 分析参数 */} {/* 分析参数 */}
<TabsContent value="analysis" className="space-y-6"> <TabsContent value="analysis" className="space-y-6">
<div className="retro-card bg-white border-2 border-black shadow-[4px_4px_0px_0px_rgba(0,0,0,1)] p-6 space-y-6"> <div className="retro-card bg-white border-2 border-black shadow-[4px_4px_0px_0px_rgba(0,0,0,1)] p-6 space-y-6">

View File

@ -0,0 +1,579 @@
/**
 * Agent audit page.
 * Live view of an autonomous AI Agent audit task: a streaming, terminal-style
 * execution log on the left and a findings/score panel on the right.
 */
import { useState, useEffect, useRef, useCallback } from "react";
import { useParams, useNavigate } from "react-router-dom";
import {
Terminal, Bot, Cpu, Shield, AlertTriangle, CheckCircle2,
Loader2, Code, Zap, Activity, ChevronRight, XCircle,
FileCode, Search, Bug, Lock, Play, Square, RefreshCw,
ArrowLeft, Download, ExternalLink
} from "lucide-react";
import { Button } from "@/components/ui/button";
import { Badge } from "@/components/ui/badge";
import { ScrollArea } from "@/components/ui/scroll-area";
import { toast } from "sonner";
import {
type AgentTask,
type AgentEvent,
type AgentFinding,
getAgentTask,
getAgentEvents,
getAgentFindings,
cancelAgentTask,
streamAgentEvents,
} from "@/shared/api/agentTasks";
// Maps each stream event type to the icon rendered at the start of its log line.
const eventTypeIcons: Record<string, React.ReactNode> = {
  phase_start: <Zap className="w-3 h-3 text-cyan-400" />,
  phase_complete: <CheckCircle2 className="w-3 h-3 text-green-400" />,
  thinking: <Cpu className="w-3 h-3 text-purple-400" />,
  tool_call: <Code className="w-3 h-3 text-yellow-400" />,
  tool_result: <CheckCircle2 className="w-3 h-3 text-green-400" />,
  tool_error: <XCircle className="w-3 h-3 text-red-400" />,
  finding_new: <Bug className="w-3 h-3 text-orange-400" />,
  finding_verified: <Shield className="w-3 h-3 text-red-400" />,
  info: <Activity className="w-3 h-3 text-blue-400" />,
  warning: <AlertTriangle className="w-3 h-3 text-yellow-400" />,
  error: <XCircle className="w-3 h-3 text-red-500" />,
  progress: <RefreshCw className="w-3 h-3 text-cyan-400 animate-spin" />,
  task_complete: <CheckCircle2 className="w-3 h-3 text-green-500" />,
  task_error: <XCircle className="w-3 h-3 text-red-500" />,
  task_cancel: <Square className="w-3 h-3 text-yellow-500" />,
};
// Maps each stream event type to the Tailwind text-color classes of its log line.
const eventTypeColors: Record<string, string> = {
  phase_start: "text-cyan-400 font-bold",
  phase_complete: "text-green-400",
  thinking: "text-purple-300",
  tool_call: "text-yellow-300",
  tool_result: "text-green-300",
  tool_error: "text-red-400",
  finding_new: "text-orange-300",
  finding_verified: "text-red-300",
  info: "text-gray-300",
  warning: "text-yellow-300",
  error: "text-red-400",
  progress: "text-cyan-300",
  task_complete: "text-green-400 font-bold",
  task_error: "text-red-400 font-bold",
  task_cancel: "text-yellow-400",
};
// Finding severity → card background/border/text classes.
const severityColors: Record<string, string> = {
  critical: "bg-red-900/50 border-red-500 text-red-300",
  high: "bg-orange-900/50 border-orange-500 text-orange-300",
  medium: "bg-yellow-900/50 border-yellow-500 text-yellow-300",
  low: "bg-blue-900/50 border-blue-500 text-blue-300",
  info: "bg-gray-900/50 border-gray-500 text-gray-300",
};
// Finding severity → emoji marker shown on the card and in the counters.
const severityIcons: Record<string, string> = {
  critical: "🔴",
  high: "🟠",
  medium: "🟡",
  low: "🟢",
  info: "⚪",
};
/**
 * Live view of one Agent audit task.
 *
 * Loads the task, its event backlog, and its findings, then attaches to the
 * server event stream (streamAgentEvents) until the task reaches a terminal
 * state. Left pane: terminal-style execution log. Right pane: findings and,
 * once finished, the task's scores.
 */
export default function AgentAuditPage() {
  const { taskId } = useParams<{ taskId: string }>();
  const navigate = useNavigate();
  const [task, setTask] = useState<AgentTask | null>(null);
  const [events, setEvents] = useState<AgentEvent[]>([]);
  const [findings, setFindings] = useState<AgentFinding[]>([]);
  const [isLoading, setIsLoading] = useState(true);
  const [isStreaming, setIsStreaming] = useState(false);
  const [currentTime, setCurrentTime] = useState(new Date().toLocaleTimeString("zh-CN", { hour12: false }));
  const eventsEndRef = useRef<HTMLDivElement>(null);
  const abortControllerRef = useRef<AbortController | null>(null);

  // Terminal states — no further events will arrive, streaming stops
  const isComplete = task?.status === "completed" || task?.status === "failed" || task?.status === "cancelled";

  // Fetch the task record
  const loadTask = useCallback(async () => {
    if (!taskId) return;
    try {
      const taskData = await getAgentTask(taskId);
      setTask(taskData);
    } catch (error) {
      console.error("Failed to load task:", error);
      toast.error("加载任务失败");
    }
  }, [taskId]);

  // Fetch the (bounded) event backlog
  const loadEvents = useCallback(async () => {
    if (!taskId) return;
    try {
      const eventsData = await getAgentEvents(taskId, { limit: 500 });
      setEvents(eventsData);
    } catch (error) {
      console.error("Failed to load events:", error);
    }
  }, [taskId]);

  // Fetch the findings list
  const loadFindings = useCallback(async () => {
    if (!taskId) return;
    try {
      const findingsData = await getAgentFindings(taskId);
      setFindings(findingsData);
    } catch (error) {
      console.error("Failed to load findings:", error);
    }
  }, [taskId]);

  // Initial load: task + events + findings in parallel
  useEffect(() => {
    const init = async () => {
      setIsLoading(true);
      await Promise.all([loadTask(), loadEvents(), loadFindings()]);
      setIsLoading(false);
    };
    init();
  }, [loadTask, loadEvents, loadFindings]);

  // Live event stream, resumed after the highest sequence already loaded.
  // NOTE(review): `events` is read here but deliberately not in the dependency
  // array, so each incoming event does not tear down and restart the stream.
  useEffect(() => {
    if (!taskId || isComplete || isLoading) return;
    const startStreaming = async () => {
      setIsStreaming(true);
      abortControllerRef.current = new AbortController();
      try {
        const lastSequence = events.length > 0 ? Math.max(...events.map(e => e.sequence)) : 0;
        for await (const event of streamAgentEvents(taskId, lastSequence, abortControllerRef.current.signal)) {
          setEvents(prev => {
            // Drop duplicates (overlap between backlog and stream)
            if (prev.some(e => e.id === event.id)) return prev;
            return [...prev, event];
          });
          // Finding events: refresh the findings panel
          if (event.event_type.startsWith("finding_")) {
            loadFindings();
          }
          // Terminal events: refresh the task status and findings
          if (["task_complete", "task_error", "task_cancel"].includes(event.event_type)) {
            loadTask();
            loadFindings();
          }
        }
      } catch (error) {
        if ((error as Error).name !== "AbortError") {
          console.error("Event stream error:", error);
        }
      } finally {
        setIsStreaming(false);
      }
    };
    startStreaming();
    return () => {
      abortControllerRef.current?.abort();
    };
  }, [taskId, isComplete, isLoading, loadTask, loadFindings]);

  // Keep the log scrolled to the newest event
  useEffect(() => {
    eventsEndRef.current?.scrollIntoView({ behavior: "smooth" });
  }, [events]);

  // Header clock, ticking every second
  useEffect(() => {
    const interval = setInterval(() => {
      setCurrentTime(new Date().toLocaleTimeString("zh-CN", { hour12: false }));
    }, 1000);
    return () => clearInterval(interval);
  }, []);

  // Confirm with the user, then cancel the task server-side
  const handleCancel = async () => {
    if (!taskId) return;
    if (!confirm("确定要取消此任务吗?已分析的结果将被保留。")) {
      return;
    }
    try {
      await cancelAgentTask(taskId);
      toast.success("任务已取消");
      loadTask();
    } catch (error) {
      toast.error("取消失败");
    }
  };

  if (isLoading) {
    return (
      <div className="min-h-screen bg-[#0a0a0f] flex items-center justify-center">
        <div className="text-center">
          <Loader2 className="w-12 h-12 text-cyan-400 animate-spin mx-auto mb-4" />
          <p className="text-gray-400 font-mono">...</p>
        </div>
      </div>
    );
  }

  if (!task) {
    return (
      <div className="min-h-screen bg-[#0a0a0f] flex items-center justify-center">
        <div className="text-center">
          <XCircle className="w-12 h-12 text-red-400 mx-auto mb-4" />
          <p className="text-gray-400 font-mono"></p>
          <Button
            variant="outline"
            className="mt-4"
            onClick={() => navigate("/tasks")}
          >
          </Button>
        </div>
      </div>
    );
  }

  return (
    <div className="min-h-screen bg-[#0a0a0f] text-white font-mono">
      {/* Top status bar */}
      <div className="h-14 bg-[#12121a] border-b-2 border-cyan-900/50 flex items-center justify-between px-6">
        <div className="flex items-center gap-4">
          <Button
            variant="ghost"
            size="sm"
            onClick={() => navigate(-1)}
            className="text-gray-400 hover:text-white"
          >
            <ArrowLeft className="w-4 h-4 mr-1" />
          </Button>
          <div className="flex items-center gap-3">
            <div className="p-1.5 bg-cyan-900/30 rounded border border-cyan-700/50">
              <Bot className={`w-5 h-5 text-cyan-400 ${!isComplete && "animate-pulse"}`} />
            </div>
            <div>
              <span className="text-xs text-gray-500 block">AGENT_AUDIT</span>
              <span className="text-sm font-bold tracking-wider text-cyan-400">
                {task.name || `任务 ${task.id.slice(0, 8)}`}
              </span>
            </div>
          </div>
        </div>
        <div className="flex items-center gap-6">
          {/* Phase indicator dots */}
          <PhaseIndicator phase={task.current_phase} status={task.status} />
          {/* Status badge */}
          <StatusBadge status={task.status} />
          {/* Clock */}
          <span className="text-gray-500 text-sm">{currentTime}</span>
        </div>
      </div>
      <div className="flex h-[calc(100vh-56px)]">
        {/* Left: execution log */}
        <div className="flex-1 p-4 flex flex-col min-w-0">
          <div className="flex items-center justify-between mb-3">
            <div className="flex items-center gap-2 text-xs text-cyan-400">
              <Terminal className="w-4 h-4" />
              <span className="uppercase tracking-wider">Execution Log</span>
              {isStreaming && (
                <span className="flex items-center gap-1 text-green-400">
                  <span className="w-2 h-2 bg-green-400 rounded-full animate-pulse" />
                  LIVE
                </span>
              )}
            </div>
            <span className="text-xs text-gray-500">{events.length} events</span>
          </div>
          {/* Terminal window */}
          <div className="flex-1 bg-[#0d0d12] rounded-lg border border-gray-800 overflow-hidden relative">
            {/* CRT scanline overlay */}
            <div className="absolute inset-0 pointer-events-none z-10 opacity-[0.03]"
              style={{
                backgroundImage: "repeating-linear-gradient(0deg, transparent, transparent 1px, rgba(0, 255, 255, 0.03) 1px, rgba(0, 255, 255, 0.03) 2px)",
              }}
            />
            <ScrollArea className="h-full">
              <div className="p-4 space-y-1">
                {events.map((event) => (
                  <EventLine key={event.id} event={event} />
                ))}
                {/* Blinking cursor while the task is still running */}
                {!isComplete && (
                  <div className="flex items-center gap-2 mt-2">
                    <span className="text-gray-600 text-xs">{currentTime}</span>
                    <span className="text-cyan-400 animate-pulse"></span>
                  </div>
                )}
                <div ref={eventsEndRef} />
              </div>
            </ScrollArea>
          </div>
          {/* Bottom control bar */}
          <div className="mt-3 flex items-center justify-between">
            <div className="flex items-center gap-4">
              {/* Progress bar */}
              <div className="flex items-center gap-2">
                <span className="text-xs text-gray-500">Progress</span>
                <div className="w-32 h-2 bg-gray-800 rounded-full overflow-hidden">
                  <div
                    className="h-full bg-gradient-to-r from-cyan-600 to-cyan-400 transition-all duration-300"
                    style={{ width: `${task.progress_percentage}%` }}
                  />
                </div>
                <span className="text-xs text-cyan-400">{task.progress_percentage.toFixed(0)}%</span>
              </div>
              {/* Indexed chunk count */}
              {task.total_chunks > 0 && (
                <div className="text-xs text-gray-500">
                  Chunks: <span className="text-gray-300">{task.total_chunks}</span>
                </div>
              )}
            </div>
            <div className="flex items-center gap-2">
              {!isComplete && (
                <Button
                  size="sm"
                  variant="outline"
                  onClick={handleCancel}
                  className="h-8 bg-transparent border-red-800 text-red-400 hover:bg-red-900/30 font-mono text-xs"
                >
                  <Square className="w-3 h-3 mr-1" />
                </Button>
              )}
              {isComplete && (
                <Button
                  size="sm"
                  variant="outline"
                  onClick={() => navigate(`/tasks/${taskId}`)}
                  className="h-8 bg-transparent border-cyan-800 text-cyan-400 hover:bg-cyan-900/30 font-mono text-xs"
                >
                  <ExternalLink className="w-3 h-3 mr-1" />
                </Button>
              )}
            </div>
          </div>
        </div>
        {/* Right: findings panel */}
        <div className="w-80 bg-[#12121a] border-l border-gray-800 flex flex-col">
          <div className="p-4 border-b border-gray-800">
            <div className="flex items-center justify-between">
              <div className="flex items-center gap-2 text-xs text-red-400">
                <Shield className="w-4 h-4" />
                <span className="uppercase tracking-wider">Findings</span>
              </div>
              <Badge variant="outline" className="bg-red-900/30 border-red-700 text-red-400">
                {findings.length}
              </Badge>
            </div>
            {/* Per-severity counters */}
            <div className="flex items-center gap-3 mt-3 text-xs">
              {task.critical_count > 0 && (
                <span className="text-red-400">🔴 {task.critical_count}</span>
              )}
              {task.high_count > 0 && (
                <span className="text-orange-400">🟠 {task.high_count}</span>
              )}
              {task.medium_count > 0 && (
                <span className="text-yellow-400">🟡 {task.medium_count}</span>
              )}
              {task.low_count > 0 && (
                <span className="text-blue-400">🟢 {task.low_count}</span>
              )}
            </div>
          </div>
          <ScrollArea className="flex-1">
            <div className="p-3 space-y-2">
              {findings.length === 0 ? (
                <div className="text-center text-gray-500 py-8">
                  <Search className="w-8 h-8 mx-auto mb-2 opacity-50" />
                  <p className="text-sm"></p>
                </div>
              ) : (
                findings.map((finding) => (
                  <FindingCard key={finding.id} finding={finding} />
                ))
              )}
            </div>
          </ScrollArea>
          {/* Scores, shown once the task has finished */}
          {isComplete && (
            <div className="p-4 border-t border-gray-800 space-y-2">
              <div className="flex items-center justify-between text-xs">
                <span className="text-gray-500"></span>
                <span className={`font-bold ${
                  task.security_score >= 80 ? "text-green-400" :
                  task.security_score >= 60 ? "text-yellow-400" :
                  "text-red-400"
                }`}>
                  {task.security_score.toFixed(0)}/100
                </span>
              </div>
              <div className="flex items-center justify-between text-xs">
                <span className="text-gray-500"></span>
                <span className="text-cyan-400">{task.verified_count}/{task.findings_count}</span>
              </div>
            </div>
          )}
        </div>
      </div>
    </div>
  );
}
// Renders the five pipeline phases as a row of status dots plus the active phase label.
function PhaseIndicator({ phase, status }: { phase: string | null; status: string }) {
  const PHASES = ["planning", "indexing", "analysis", "verification", "reporting"];
  const activeIdx = phase ? PHASES.indexOf(phase) : -1;
  const done = status === "completed";
  const failed = status === "failed";

  // Color for one dot: active > already-passed > failed > pending.
  const dotClass = (name: string, idx: number): string => {
    if (name === phase) {
      return "bg-cyan-400 shadow-[0_0_8px_rgba(34,211,238,0.6)] animate-pulse";
    }
    if (done || (activeIdx >= 0 && idx < activeIdx)) {
      return "bg-cyan-600";
    }
    return failed ? "bg-red-900" : "bg-gray-700";
  };

  return (
    <div className="flex items-center gap-1">
      {PHASES.map((p, idx) => (
        <div
          key={p}
          className={`w-2 h-2 rounded-full transition-all ${dotClass(p, idx)}`}
          title={p}
        />
      ))}
      {phase && (
        <span className="ml-2 text-xs text-gray-400 uppercase">{phase}</span>
      )}
    </div>
  );
}
// Renders a task status as a colored, monospace badge; unknown statuses fall back to "pending".
function StatusBadge({ status }: { status: string }) {
  const STATUS_STYLES: Record<string, { text: string; className: string }> = {
    pending: { text: "PENDING", className: "bg-gray-800 text-gray-400 border-gray-600" },
    initializing: { text: "INIT", className: "bg-blue-900/50 text-blue-400 border-blue-600 animate-pulse" },
    planning: { text: "PLANNING", className: "bg-purple-900/50 text-purple-400 border-purple-600 animate-pulse" },
    indexing: { text: "INDEXING", className: "bg-cyan-900/50 text-cyan-400 border-cyan-600 animate-pulse" },
    analyzing: { text: "ANALYZING", className: "bg-yellow-900/50 text-yellow-400 border-yellow-600 animate-pulse" },
    verifying: { text: "VERIFYING", className: "bg-orange-900/50 text-orange-400 border-orange-600 animate-pulse" },
    completed: { text: "COMPLETED", className: "bg-green-900/50 text-green-400 border-green-600" },
    failed: { text: "FAILED", className: "bg-red-900/50 text-red-400 border-red-600" },
    cancelled: { text: "CANCELLED", className: "bg-yellow-900/50 text-yellow-400 border-yellow-600" },
  };
  const { text, className } = STATUS_STYLES[status] ?? STATUS_STYLES.pending;
  return (
    <Badge variant="outline" className={`${className} font-mono text-xs px-2`}>
      {text}
    </Badge>
  );
}
// One line of the execution log: timestamp, type icon, message, optional tool duration.
function EventLine({ event }: { event: AgentEvent }) {
  const icon = eventTypeIcons[event.event_type] || <ChevronRight className="w-3 h-3 text-gray-500" />;
  const colorClass = eventTypeColors[event.event_type] || "text-gray-400";
  const timestamp = event.timestamp
    ? new Date(event.timestamp).toLocaleTimeString("zh-CN", { hour12: false })
    : "";
  return (
    <div className={`flex items-start gap-2 py-0.5 group hover:bg-white/5 px-1 rounded ${colorClass}`}>
      <span className="text-gray-600 text-xs w-20 flex-shrink-0 group-hover:text-gray-500">
        {timestamp}
      </span>
      <span className="flex-shrink-0 mt-0.5">{icon}</span>
      <span className="flex-1 text-sm break-all">
        {event.message}
        {/* Explicit null check: tool_duration_ms is number | null, and a 0ms
            duration is falsy — plain `&&` would make React render a stray "0"
            text node instead of hiding the span. */}
        {event.tool_duration_ms != null && (
          <span className="text-gray-600 ml-2">({event.tool_duration_ms}ms)</span>
        )}
      </span>
    </div>
  );
}
// Compact card for one vulnerability finding in the right-hand panel:
// severity marker, title, type, optional file location, and verified/PoC badges.
function FindingCard({ finding }: { finding: AgentFinding }) {
  const severityClass = severityColors[finding.severity] || severityColors.info;
  const marker = severityIcons[finding.severity] || "⚪";

  return (
    <div className={`p-3 rounded border-l-4 ${severityClass} transition-all hover:brightness-110`}>
      <div className="flex items-start gap-2">
        <span>{marker}</span>
        <div className="flex-1 min-w-0">
          <p className="text-sm font-medium truncate">{finding.title}</p>
          <p className="text-xs text-gray-400 mt-0.5">{finding.vulnerability_type}</p>
          {finding.file_path && (
            <p className="text-xs text-gray-500 mt-1 truncate" title={finding.file_path}>
              <FileCode className="w-3 h-3 inline mr-1" />
              {finding.file_path}:{finding.line_start}
            </p>
          )}
        </div>
      </div>
      <div className="flex items-center gap-2 mt-2">
        {finding.is_verified && (
          <Badge variant="outline" className="text-[10px] px-1 py-0 bg-green-900/30 border-green-700 text-green-400">
            <CheckCircle2 className="w-2.5 h-2.5 mr-0.5" />
          </Badge>
        )}
        {finding.has_poc && (
          <Badge variant="outline" className="text-[10px] px-1 py-0 bg-red-900/30 border-red-700 text-red-400">
            <Code className="w-2.5 h-2.5 mr-0.5" />
            PoC
          </Badge>
        )}
      </div>
    </div>
  );
}

// ===== diff artifact: start of new file (agent tasks API client, +306 lines) =====
/**
 * Agent Tasks API
 * Client for the Agent audit task endpoints (/agent-tasks).
 */
import { apiClient } from "./serverClient";
// ============ Types ============
// One autonomous agent audit run over a project. Mirrors the backend
// agent_tasks row; timestamps are ISO strings, null until the phase is reached.
export interface AgentTask {
  id: string;
  project_id: string;
  name: string | null;
  description: string | null;
  task_type: string;
  status: string;
  current_phase: string | null;
  current_step: string | null;
  // Progress counters
  total_files: number;
  indexed_files: number;
  analyzed_files: number;
  total_chunks: number;
  findings_count: number;
  verified_count: number;
  false_positive_count: number;
  // Severity distribution of findings
  critical_count: number;
  high_count: number;
  medium_count: number;
  low_count: number;
  // Scores
  quality_score: number;
  security_score: number;
  // Timestamps
  created_at: string;
  started_at: string | null;
  completed_at: string | null;
  // Overall progress (0-100)
  progress_percentage: number;
}
// A single vulnerability finding produced by an agent task, including
// optional location, PoC/fix code, and the model's explanation/confidence.
export interface AgentFinding {
  id: string;
  task_id: string;
  vulnerability_type: string;
  severity: string;
  title: string;
  description: string | null;
  // Location in the audited codebase (null when not tied to a file)
  file_path: string | null;
  line_start: number | null;
  line_end: number | null;
  code_snippet: string | null;
  status: string;
  is_verified: boolean;
  has_poc: boolean;
  poc_code: string | null;
  suggestion: string | null;
  fix_code: string | null;
  ai_explanation: string | null;
  ai_confidence: number | null;
  created_at: string;
}
// A timeline event emitted while an agent task runs (tool calls, phase
// changes, findings). `sequence` is monotonically increasing per task and
// is used as the SSE resume cursor (after_sequence).
export interface AgentEvent {
  id: string;
  task_id: string;
  event_type: string;
  phase: string | null;
  message: string | null;
  // Populated only for tool-invocation events
  tool_name: string | null;
  tool_input?: Record<string, unknown>;
  tool_output?: Record<string, unknown>;
  tool_duration_ms: number | null;
  finding_id: string | null;
  tokens_used?: number;
  metadata?: Record<string, unknown>;
  sequence: number;
  timestamp: string;
}
// Payload for creating an agent audit task; everything except project_id
// is optional and falls back to server-side defaults.
export interface CreateAgentTaskRequest {
  project_id: string;
  name?: string;
  description?: string;
  audit_scope?: Record<string, unknown>;
  target_vulnerabilities?: string[];
  // How far to take verification of each finding
  verification_level?: "analysis_only" | "sandbox" | "generate_poc";
  branch_name?: string;
  exclude_patterns?: string[];
  target_files?: string[];
  // Resource limits for the run
  max_iterations?: number;
  token_budget?: number;
  timeout_seconds?: number;
}
// Aggregated result view of a finished (or running) agent task, as returned
// by GET /agent-tasks/{id}/summary.
export interface AgentTaskSummary {
  task_id: string;
  status: string;
  progress_percentage: number;
  security_score: number;
  quality_score: number;
  statistics: {
    total_files: number;
    indexed_files: number;
    analyzed_files: number;
    total_chunks: number;
    findings_count: number;
    verified_count: number;
    false_positive_count: number;
  };
  severity_distribution: {
    critical: number;
    high: number;
    medium: number;
    low: number;
  };
  // Per-vulnerability-type counts: total found vs. verified
  vulnerability_types: Record<string, { total: number; verified: number }>;
  // Wall-clock duration; null while the task is still running
  duration_seconds: number | null;
}
// ============ API Functions ============
/** Create a new Agent audit task and return the created record. */
export async function createAgentTask(data: CreateAgentTaskRequest): Promise<AgentTask> {
  const { data: task } = await apiClient.post("/agent-tasks/", data);
  return task;
}
/** List Agent tasks, optionally filtered by project/status, with pagination. */
export async function getAgentTasks(params?: {
  project_id?: string;
  status?: string;
  skip?: number;
  limit?: number;
}): Promise<AgentTask[]> {
  const { data } = await apiClient.get("/agent-tasks/", { params });
  return data;
}
/** Fetch a single Agent task by id. */
export async function getAgentTask(taskId: string): Promise<AgentTask> {
  const { data } = await apiClient.get(`/agent-tasks/${taskId}`);
  return data;
}
/** Start execution of a previously created Agent task. */
export async function startAgentTask(taskId: string): Promise<{ message: string; task_id: string }> {
  const { data } = await apiClient.post(`/agent-tasks/${taskId}/start`);
  return data;
}
/** Request cancellation of a running Agent task. */
export async function cancelAgentTask(taskId: string): Promise<{ message: string; task_id: string }> {
  const { data } = await apiClient.post(`/agent-tasks/${taskId}/cancel`);
  return data;
}
/** Fetch task events (non-streaming); use after_sequence to page forward. */
export async function getAgentEvents(
  taskId: string,
  params?: { after_sequence?: number; limit?: number }
): Promise<AgentEvent[]> {
  const { data } = await apiClient.get(`/agent-tasks/${taskId}/events/list`, { params });
  return data;
}
/** List a task's findings, optionally filtered by severity/type/verification. */
export async function getAgentFindings(
  taskId: string,
  params?: {
    severity?: string;
    vulnerability_type?: string;
    is_verified?: boolean;
  }
): Promise<AgentFinding[]> {
  const { data } = await apiClient.get(`/agent-tasks/${taskId}/findings`, { params });
  return data;
}
/** Fetch one finding of a task by id. */
export async function getAgentFinding(taskId: string, findingId: string): Promise<AgentFinding> {
  const { data } = await apiClient.get(`/agent-tasks/${taskId}/findings/${findingId}`);
  return data;
}
/** Partially update a finding (currently only its status). */
export async function updateAgentFinding(
  taskId: string,
  findingId: string,
  data: { status?: string }
): Promise<AgentFinding> {
  const { data: updated } = await apiClient.patch(`/agent-tasks/${taskId}/findings/${findingId}`, data);
  return updated;
}
/** Fetch the aggregated summary (scores, statistics, distributions) for a task. */
export async function getAgentTaskSummary(taskId: string): Promise<AgentTaskSummary> {
  const { data } = await apiClient.get(`/agent-tasks/${taskId}/summary`);
  return data;
}
/**
 * Open a native SSE connection for a task's event stream.
 *
 * 注意: EventSource 不支持自定义 headers — authentication must come via
 * cookies (withCredentials) or URL params. For token-based auth, use
 * streamAgentEvents (fetch + ReadableStream) instead.
 */
export function createAgentEventSource(taskId: string, afterSequence = 0): EventSource {
  const base = import.meta.env.VITE_API_URL || "";
  const endpoint = `${base}/api/v1/agent-tasks/${taskId}/events?after_sequence=${afterSequence}`;
  return new EventSource(endpoint, { withCredentials: true });
}
/**
 * Stream a task's events over SSE using fetch, which (unlike EventSource)
 * supports an Authorization header.
 *
 * Yields each parsed AgentEvent. Stops when the server closes the stream or
 * `signal` aborts the request (the pending read rejects with AbortError).
 */
export async function* streamAgentEvents(
  taskId: string,
  afterSequence = 0,
  signal?: AbortSignal
): AsyncGenerator<AgentEvent, void, unknown> {
  const token = localStorage.getItem("auth_token");
  const baseUrl = import.meta.env.VITE_API_URL || "";
  const url = `${baseUrl}/api/v1/agent-tasks/${taskId}/events?after_sequence=${afterSequence}`;
  const headers: Record<string, string> = { Accept: "text/event-stream" };
  // Don't send the literal string "Bearer null" when no token is stored.
  if (token) {
    headers.Authorization = `Bearer ${token}`;
  }
  const response = await fetch(url, { headers, signal });
  if (!response.ok) {
    throw new Error(`Failed to connect to event stream: ${response.statusText}`);
  }
  const reader = response.body?.getReader();
  if (!reader) {
    throw new Error("No response body");
  }
  const decoder = new TextDecoder();
  let buffer = "";

  // Parse one SSE line; returns null for non-data or malformed lines.
  const parseLine = (rawLine: string): AgentEvent | null => {
    // SSE allows CRLF line endings; strip a trailing \r before matching.
    const line = rawLine.endsWith("\r") ? rawLine.slice(0, -1) : rawLine;
    if (!line.startsWith("data: ")) {
      return null;
    }
    try {
      return JSON.parse(line.slice(6)) as AgentEvent;
    } catch {
      // 忽略解析错误
      return null;
    }
  };

  try {
    while (true) {
      const { done, value } = await reader.read();
      if (done) {
        break;
      }
      buffer += decoder.decode(value, { stream: true });
      const lines = buffer.split("\n");
      buffer = lines.pop() || "";
      for (const rawLine of lines) {
        const event = parseLine(rawLine);
        if (event) {
          yield event;
        }
      }
    }
    // Flush the decoder's internal state and process the last unterminated
    // line, so a final event without a trailing newline is not dropped.
    buffer += decoder.decode();
    const tail = parseLine(buffer);
    if (tail) {
      yield tail;
    }
  } finally {
    reader.releaseLock();
  }
}