258 lines
8.0 KiB
Python
258 lines
8.0 KiB
Python
|
|
"""
|
|||
|
|
知识查询工具 - 让Agent可以在运行时查询安全知识
|
|||
|
|
|
|||
|
|
基于RAG的知识检索工具
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import logging
|
|||
|
|
from typing import Dict, Any, Optional, List, Type
|
|||
|
|
from pydantic import BaseModel, Field
|
|||
|
|
|
|||
|
|
from ..tools.base import AgentTool, ToolResult
|
|||
|
|
from .rag_knowledge import security_knowledge_rag, KnowledgeCategory
|
|||
|
|
|
|||
|
|
# Module-scoped logger, named after this module, used by the tools below.
logger: logging.Logger = logging.getLogger(__name__)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class SecurityKnowledgeQueryInput(BaseModel):
    """Arguments accepted by the security-knowledge search tool."""

    # Required free-text search string.
    query: str = Field(
        ...,
        description="搜索查询,如漏洞类型、技术名称、安全概念等",
    )
    # Optional category name used to filter the search; None means no filter.
    category: Optional[str] = Field(
        default=None,
        description="知识类别过滤: vulnerability, best_practice, remediation, code_pattern, compliance",
    )
    # How many hits to return; validated to the inclusive range 1..10.
    top_k: int = Field(
        default=3,
        ge=1,
        le=10,
        description="返回结果数量",
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
class SecurityKnowledgeQueryTool(AgentTool):
    """Agent tool for free-text RAG search over the security knowledge base.

    Lets an agent look up vulnerability knowledge, best practices,
    remediation advice and similar material while it is running.
    """

    @property
    def name(self) -> str:
        return "query_security_knowledge"

    @property
    def description(self) -> str:
        return """查询安全知识库,获取漏洞类型、检测方法、修复建议等专业知识。

使用场景:
- 需要了解某种漏洞类型的详细信息
- 查找安全最佳实践
- 获取修复建议
- 了解特定技术的安全考量

示例查询:
- "SQL injection detection methods"
- "XSS prevention best practices"
- "SSRF vulnerability patterns"
- "hardcoded credentials"
"""

    @property
    def args_schema(self) -> Type[BaseModel]:
        return SecurityKnowledgeQueryInput

    @staticmethod
    def _join(values: Any) -> str:
        """Render a tag/ID collection as a comma-separated string.

        A bare string is returned unchanged (``", ".join`` would otherwise
        split it into single characters), and every item is passed through
        ``str`` so non-string IDs (e.g. integer CWE numbers) do not make
        ``str.join`` raise ``TypeError``.
        """
        if isinstance(values, str):
            return values
        return ", ".join(str(value) for value in values)

    @classmethod
    def _format_result(cls, index: int, result: Dict[str, Any]) -> str:
        """Render one search hit (a dict returned by the RAG search) as Markdown."""
        text = f"### 结果 {index}"
        if result.get("title"):
            text += f": {result['title']}"
        # A missing *or explicit None* score is shown as 0; formatting None
        # with ``:.2f`` would raise TypeError.
        score = result.get("score")
        if score is None:
            score = 0
        text += f"\n相关度: {score:.2f}\n"
        if result.get("tags"):
            text += f"标签: {cls._join(result['tags'])}\n"
        if result.get("cwe_ids"):
            text += f"CWE: {cls._join(result['cwe_ids'])}\n"
        # ``or ''`` keeps an explicit None content from rendering as "None".
        text += f"\n{result.get('content') or ''}"
        return text

    async def _execute(
        self,
        query: str,
        category: Optional[str] = None,
        top_k: int = 3,
    ) -> ToolResult:
        """Search the knowledge base and format the hits for the agent.

        Args:
            query: Free-text search string.
            category: Optional category name, trimmed and lower-cased before
                being converted with ``KnowledgeCategory(...)``. An unknown
                value still results in an unfiltered search (as before), but
                the fallback is now logged, mentioned in the output text and
                recorded as ``metadata["ignored_category"]``.
            top_k: Number of hits requested from the search backend.

        Returns:
            A successful ToolResult whose ``data`` is a Markdown summary of the
            hits (or a "nothing found" message), or a failed ToolResult with
            the error text if anything raises.
        """
        try:
            knowledge_category = None
            ignored_category: Optional[str] = None
            if category:
                try:
                    knowledge_category = KnowledgeCategory(category.strip().lower())
                except ValueError:
                    # Keep the best-effort fallback to an unfiltered search,
                    # but surface it instead of silently swallowing it.
                    ignored_category = category
                    logger.warning(
                        "Unknown knowledge category %r; searching without category filter",
                        category,
                    )

            results = await security_knowledge_rag.search(
                query=query,
                category=knowledge_category,
                top_k=top_k,
            )

            note = ""
            metadata: Dict[str, Any] = {
                "query": query,
                "results_count": len(results) if results else 0,
            }
            if ignored_category is not None:
                note = f"注意: 未知的知识类别 '{ignored_category}',已忽略类别过滤。\n\n"
                metadata["ignored_category"] = ignored_category

            if not results:
                return ToolResult(
                    success=True,
                    data=note + "未找到相关的安全知识。请尝试使用不同的关键词。",
                    metadata=metadata,
                )

            formatted_results = [
                self._format_result(index, result)
                for index, result in enumerate(results, 1)
            ]
            output = (
                note
                + f"找到 {len(results)} 条相关知识:\n\n"
                + "\n\n---\n\n".join(formatted_results)
            )
            metadata["results"] = results

            return ToolResult(
                success=True,
                data=output,
                metadata=metadata,
            )

        except Exception as e:
            # Tool boundary: report the failure to the agent instead of raising,
            # and keep the traceback in the log.
            logger.exception("Knowledge query failed: %s", e)
            return ToolResult(
                success=False,
                error=f"知识查询失败: {str(e)}",
            )
|
|||
|
|
|
|||
|
|
|
|||
|
|
class VulnerabilityKnowledgeInput(BaseModel):
    """Arguments accepted by the vulnerability-knowledge lookup tool."""

    # Required identifier of the vulnerability class to look up; the field
    # description lists the expected snake_case names.
    vulnerability_type: str = Field(
        ...,
        description="漏洞类型,如: sql_injection, xss, command_injection, path_traversal, ssrf, deserialization, hardcoded_secrets, auth_bypass",
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
class GetVulnerabilityKnowledgeTool(AgentTool):
    """Agent tool that returns the full knowledge entry for one vulnerability type.

    The entry covers detection methods, dangerous patterns, remediation
    advice and related material, as provided by the knowledge backend.
    """

    @property
    def name(self) -> str:
        return "get_vulnerability_knowledge"

    @property
    def description(self) -> str:
        return """获取特定漏洞类型的完整专业知识。

支持的漏洞类型:
- sql_injection: SQL注入
- xss: 跨站脚本攻击
- command_injection: 命令注入
- path_traversal: 路径遍历
- ssrf: 服务端请求伪造
- deserialization: 不安全的反序列化
- hardcoded_secrets: 硬编码凭证
- auth_bypass: 认证绕过

返回内容包括:
- 漏洞概述和危害
- 危险代码模式
- 检测方法
- 安全实践
- 修复示例
"""

    @property
    def args_schema(self) -> Type[BaseModel]:
        return VulnerabilityKnowledgeInput

    @staticmethod
    def _join(values: Any) -> str:
        """Render a list of IDs/names as a comma-separated string.

        A bare string is returned unchanged, None renders as an empty string,
        and items are passed through ``str`` so non-string IDs do not make
        ``str.join`` raise ``TypeError``.
        """
        if values is None:
            return ""
        if isinstance(values, str):
            return values
        return ", ".join(str(value) for value in values)

    @staticmethod
    def _normalize_type(vulnerability_type: str) -> str:
        """Map free-form input like "SQL Injection" or "command-injection"
        to the snake_case form listed in the tool description."""
        return "_".join(vulnerability_type.strip().lower().replace("-", " ").split())

    async def _lookup(self, vulnerability_type: str) -> Any:
        """Fetch knowledge for *vulnerability_type*, retrying once with a normalized name.

        The exact input is tried first, so inputs that resolved before this
        fallback existed behave exactly as they did; the normalized retry
        only happens after a miss.
        """
        knowledge = await security_knowledge_rag.get_vulnerability_knowledge(
            vulnerability_type
        )
        if knowledge:
            return knowledge
        normalized = self._normalize_type(vulnerability_type)
        if normalized and normalized != vulnerability_type:
            return await security_knowledge_rag.get_vulnerability_knowledge(normalized)
        return knowledge

    async def _execute(self, vulnerability_type: str) -> ToolResult:
        """Look up one vulnerability type and format its knowledge entry.

        Args:
            vulnerability_type: Name of the vulnerability class, e.g.
                ``sql_injection``; case/space/hyphen variants are accepted
                through a fallback lookup.

        Returns:
            A successful ToolResult with a Markdown document in ``data``; if
            the type is unknown, a successful ToolResult listing the available
            types (also in ``metadata["available_types"]``); or a failed
            ToolResult with the error text if anything raises.
        """
        try:
            knowledge = await self._lookup(vulnerability_type)

            if not knowledge:
                available = security_knowledge_rag.get_all_vulnerability_types()
                return ToolResult(
                    success=True,
                    data=f"未找到漏洞类型 '{vulnerability_type}' 的知识。\n\n可用的漏洞类型: {self._join(available)}",
                    metadata={"available_types": available},
                )

            # ``or`` (rather than a .get default) also covers an explicit
            # empty/None title, which would otherwise render as "# None".
            title = knowledge.get("title") or vulnerability_type
            output_parts = [
                f"# {title}",
                f"严重程度: {knowledge.get('severity', 'N/A')}",
            ]

            if knowledge.get("cwe_ids"):
                output_parts.append(f"CWE: {self._join(knowledge['cwe_ids'])}")
            if knowledge.get("owasp_ids"):
                output_parts.append(f"OWASP: {self._join(knowledge['owasp_ids'])}")

            output_parts.append("")
            output_parts.append(knowledge.get("content") or "")

            return ToolResult(
                success=True,
                data="\n".join(output_parts),
                # Shallow copy so downstream consumers of the metadata cannot
                # mutate the object the backend returned (it may be shared or
                # cached there; not visible from this file).
                metadata=dict(knowledge),
            )

        except Exception as e:
            # Tool boundary: report the failure to the agent instead of raising,
            # and keep the traceback in the log.
            logger.exception("Get vulnerability knowledge failed: %s", e)
            return ToolResult(
                success=False,
                error=f"获取漏洞知识失败: {str(e)}",
            )
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ListKnowledgeModulesInput(BaseModel):
    """Arguments accepted by the knowledge-module listing tool."""

    # Optional category filter; None means no filter was requested.
    category: Optional[str] = Field(
        default=None,
        description="按类别过滤: vulnerability, best_practice, remediation",
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ListKnowledgeModulesTool(AgentTool):
    """Agent tool that lists the knowledge modules that can be enumerated.

    Only the vulnerability catalogue
    (``security_knowledge_rag.get_all_vulnerability_types()``) can be
    listed. Other categories are reported as unsupported instead of being
    silently ignored.
    """

    @property
    def name(self) -> str:
        return "list_knowledge_modules"

    @property
    def description(self) -> str:
        return "列出所有可用的安全知识模块,包括漏洞类型、最佳实践等"

    @property
    def args_schema(self) -> Type[BaseModel]:
        return ListKnowledgeModulesInput

    async def _execute(self, category: Optional[str] = None) -> ToolResult:
        """List available knowledge modules, honouring the optional category filter.

        Args:
            category: Optional category name (case-insensitive, trimmed).
                None/empty or ``vulnerability`` lists the vulnerability types.
                Any other value previously fell through to that same list,
                ignoring the filter. It now returns an explanatory message
                with an empty ``metadata["modules"]``.

        Returns:
            A successful ToolResult with a Markdown list in ``data`` and the
            module names in ``metadata["modules"]``, or a failed ToolResult
            with the error text if anything raises.
        """
        try:
            normalized = category.strip().lower() if category else None
            if normalized and normalized != "vulnerability":
                # Only the vulnerability catalogue can be enumerated from here;
                # point the agent at the search tool for other categories.
                return ToolResult(
                    success=True,
                    data=(
                        f"类别 '{category}' 暂不支持列出模块,目前仅可列出 vulnerability 类别。\n"
                        "请使用 query_security_knowledge 工具并指定 category 参数检索该类别的知识。"
                    ),
                    metadata={"modules": [], "category": category},
                )

            modules = security_knowledge_rag.get_all_vulnerability_types() or []

            lines = ["可用的安全知识模块:", "", "## 漏洞类型"]
            if modules:
                lines.extend(f"- {module}" for module in modules)
            else:
                lines.append("(暂无可用模块)")
            output = "\n".join(lines) + "\n"

            return ToolResult(
                success=True,
                data=output,
                metadata={"modules": modules},
            )

        except Exception as e:
            # Tool boundary: report the failure to the agent instead of raising,
            # and keep the traceback in the log.
            logger.exception("List knowledge modules failed: %s", e)
            return ToolResult(
                success=False,
                error=f"列出知识模块失败: {str(e)}",
            )
|