CodeReview/backend/app/services/agent/knowledge/tools.py

319 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
知识查询工具 - 让Agent可以在运行时查询安全知识
基于RAG的知识检索工具
"""
import logging
from typing import Dict, Any, Optional, List, Type
from pydantic import BaseModel, Field
from ..tools.base import AgentTool, ToolResult
from .rag_knowledge import security_knowledge_rag, KnowledgeCategory
logger = logging.getLogger(__name__)
class SecurityKnowledgeQueryInput(BaseModel):
"""安全知识查询输入"""
query: str = Field(..., description="搜索查询,如漏洞类型、技术名称、安全概念等")
category: Optional[str] = Field(
None,
description="知识类别过滤: vulnerability, best_practice, remediation, code_pattern, compliance"
)
top_k: int = Field(3, description="返回结果数量", ge=1, le=10)
class SecurityKnowledgeQueryTool(AgentTool):
"""
安全知识查询工具
用于查询安全漏洞知识、最佳实践、修复建议等
"""
@property
def name(self) -> str:
return "query_security_knowledge"
@property
def description(self) -> str:
return """查询安全知识库,获取漏洞类型、检测方法、修复建议等专业知识。
使用场景:
- 需要了解某种漏洞类型的详细信息
- 查找安全最佳实践
- 获取修复建议
- 了解特定技术的安全考量
示例查询:
- "SQL injection detection methods"
- "XSS prevention best practices"
- "SSRF vulnerability patterns"
- "hardcoded credentials"
"""
@property
def args_schema(self) -> Type[BaseModel]:
return SecurityKnowledgeQueryInput
async def _execute(
self,
query: str,
category: Optional[str] = None,
top_k: int = 3,
) -> ToolResult:
"""执行知识查询"""
try:
# 转换类别
knowledge_category = None
if category:
try:
knowledge_category = KnowledgeCategory(category.lower())
except ValueError:
pass
# 执行搜索
results = await security_knowledge_rag.search(
query=query,
category=knowledge_category,
top_k=top_k,
)
if not results:
return ToolResult(
success=True,
data="未找到相关的安全知识。请尝试使用不同的关键词。",
metadata={"query": query, "results_count": 0},
)
# 格式化结果
formatted_results = []
for i, result in enumerate(results, 1):
formatted = f"### 结果 {i}"
if result.get("title"):
formatted += f": {result['title']}"
formatted += f"\n相关度: {result.get('score', 0):.2f}\n"
if result.get("tags"):
formatted += f"标签: {', '.join(result['tags'])}\n"
if result.get("cwe_ids"):
formatted += f"CWE: {', '.join(result['cwe_ids'])}\n"
formatted += f"\n{result.get('content', '')}"
formatted_results.append(formatted)
output = f"找到 {len(results)} 条相关知识:\n\n" + "\n\n---\n\n".join(formatted_results)
return ToolResult(
success=True,
data=output,
metadata={
"query": query,
"results_count": len(results),
"results": results,
},
)
except Exception as e:
logger.error(f"Knowledge query failed: {e}")
return ToolResult(
success=False,
error=f"知识查询失败: {str(e)}",
)
class VulnerabilityKnowledgeInput(BaseModel):
"""漏洞知识查询输入"""
vulnerability_type: str = Field(
...,
description="漏洞类型,如: sql_injection, xss, command_injection, path_traversal, ssrf, deserialization, hardcoded_secrets, auth_bypass"
)
project_language: Optional[str] = Field(
None,
description="目标项目的主要编程语言(如 python, php, javascript, rust, go用于过滤相关示例"
)
class GetVulnerabilityKnowledgeTool(AgentTool):
"""
获取特定漏洞类型的完整知识
返回该漏洞类型的检测方法、危险模式、修复建议等完整信息
"""
@property
def name(self) -> str:
return "get_vulnerability_knowledge"
@property
def description(self) -> str:
return """获取特定漏洞类型的完整专业知识。
支持的漏洞类型:
- sql_injection: SQL注入
- xss: 跨站脚本攻击
- command_injection: 命令注入
- path_traversal: 路径遍历
- ssrf: 服务端请求伪造
- deserialization: 不安全的反序列化
- hardcoded_secrets: 硬编码凭证
- auth_bypass: 认证绕过
返回内容包括:
- 漏洞概述和危害
- 危险代码模式
- 检测方法
- 安全实践
- 修复示例
"""
@property
def args_schema(self) -> Type[BaseModel]:
return VulnerabilityKnowledgeInput
async def _execute(self, vulnerability_type: str, project_language: Optional[str] = None) -> ToolResult:
"""获取漏洞知识"""
try:
knowledge = await security_knowledge_rag.get_vulnerability_knowledge(
vulnerability_type
)
if not knowledge:
available = security_knowledge_rag.get_all_vulnerability_types()
return ToolResult(
success=True,
data=f"未找到漏洞类型 '{vulnerability_type}' 的知识。\n\n可用的漏洞类型: {', '.join(available)}",
metadata={"available_types": available},
)
# 格式化输出
output_parts = [
f"# {knowledge.get('title', vulnerability_type)}",
f"严重程度: {knowledge.get('severity', 'N/A')}",
]
if knowledge.get("cwe_ids"):
output_parts.append(f"CWE: {', '.join(knowledge['cwe_ids'])}")
if knowledge.get("owasp_ids"):
output_parts.append(f"OWASP: {', '.join(knowledge['owasp_ids'])}")
# 🔥 v2.2: 添加语言不匹配警告
content = knowledge.get("content", "")
knowledge_lang = self._detect_code_language(content)
if project_language and knowledge_lang:
project_lang_lower = project_language.lower()
if knowledge_lang.lower() != project_lang_lower:
output_parts.append("")
output_parts.append("=" * 60)
output_parts.append(f"⚠️ **重要警告**: 以下示例代码是 {knowledge_lang.upper()} 语言")
output_parts.append(f" 你正在审计的项目是 {project_language.upper()} 项目")
output_parts.append(" **这些代码示例仅供概念参考,不要直接套用到目标项目!**")
output_parts.append(" 请在目标项目中查找该语言特有的等效漏洞模式。")
output_parts.append("=" * 60)
output_parts.append("")
output_parts.append(content)
# 🔥 v2.2: 添加使用指南
output_parts.append("")
output_parts.append("---")
output_parts.append("📌 **使用指南**:")
output_parts.append("1. 以上知识仅供参考,你必须在实际代码中验证漏洞是否存在")
output_parts.append("2. 不要假设项目中存在示例中的代码模式")
output_parts.append("3. 只有在 read_file 读取到的代码中确实存在问题时才报告漏洞")
output_parts.append("4. 如果示例语言与项目语言不同,请查找该语言的等效漏洞模式")
return ToolResult(
success=True,
data="\n".join(output_parts),
metadata={
**knowledge,
"knowledge_language": knowledge_lang,
"project_language": project_language,
},
)
except Exception as e:
logger.error(f"Get vulnerability knowledge failed: {e}")
return ToolResult(
success=False,
error=f"获取漏洞知识失败: {str(e)}",
)
def _detect_code_language(self, content: str) -> Optional[str]:
"""检测知识内容中的主要代码语言"""
# 检测代码块中的语言标记
import re
code_blocks = re.findall(r'```(\w+)', content)
if code_blocks:
# 统计最常见的语言
from collections import Counter
lang_counts = Counter(code_blocks)
most_common = lang_counts.most_common(1)
if most_common:
return most_common[0][0]
# 基于内容特征检测
if "def " in content or "import " in content or "@app.route" in content:
return "python"
if "<?php" in content or "$_GET" in content or "$_POST" in content:
return "php"
if "function " in content and ("const " in content or "let " in content):
return "javascript"
if "func " in content and "package " in content:
return "go"
if "fn " in content and "let mut" in content:
return "rust"
if "public class" in content or "private void" in content:
return "java"
return None
class ListKnowledgeModulesInput(BaseModel):
"""列出知识模块输入"""
category: Optional[str] = Field(
None,
description="按类别过滤: vulnerability, best_practice, remediation"
)
class ListKnowledgeModulesTool(AgentTool):
"""
列出所有可用的知识模块
"""
@property
def name(self) -> str:
return "list_knowledge_modules"
@property
def description(self) -> str:
return "列出所有可用的安全知识模块,包括漏洞类型、最佳实践等"
@property
def args_schema(self) -> Type[BaseModel]:
return ListKnowledgeModulesInput
async def _execute(self, category: Optional[str] = None) -> ToolResult:
"""列出知识模块"""
try:
modules = security_knowledge_rag.get_all_vulnerability_types()
output = "可用的安全知识模块:\n\n"
output += "## 漏洞类型\n"
for module in modules:
output += f"- {module}\n"
return ToolResult(
success=True,
data=output,
metadata={"modules": modules},
)
except Exception as e:
logger.error(f"List knowledge modules failed: {e}")
return ToolResult(
success=False,
error=f"列出知识模块失败: {str(e)}",
)