208 lines
5.9 KiB
Python
208 lines
5.9 KiB
Python
"""
|
||
知识加载器 - 基于RAG的知识模块加载
|
||
|
||
将安全知识集成到Agent的系统提示词中
|
||
"""
|
||
|
||
import logging
|
||
from typing import List, Dict, Any, Optional
|
||
|
||
from .base import KnowledgeCategory
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class KnowledgeLoader:
|
||
"""
|
||
知识加载器
|
||
|
||
负责将RAG检索的知识集成到Agent系统提示词中
|
||
"""
|
||
|
||
def __init__(self, rag=None):
|
||
# 延迟导入避免循环依赖
|
||
if rag is None:
|
||
from .rag_knowledge import security_knowledge_rag
|
||
rag = security_knowledge_rag
|
||
self._rag = rag
|
||
|
||
async def load_module(self, module_name: str) -> str:
|
||
"""
|
||
加载单个知识模块
|
||
|
||
Args:
|
||
module_name: 模块名称(如sql_injection, xss等)
|
||
|
||
Returns:
|
||
模块内容
|
||
"""
|
||
knowledge = await self._rag.get_vulnerability_knowledge(module_name)
|
||
if knowledge:
|
||
return knowledge.get("content", "")
|
||
return ""
|
||
|
||
async def load_modules(self, module_names: List[str]) -> Dict[str, str]:
|
||
"""
|
||
批量加载知识模块
|
||
|
||
Args:
|
||
module_names: 模块名称列表
|
||
|
||
Returns:
|
||
模块名称到内容的映射
|
||
"""
|
||
result = {}
|
||
for name in module_names:
|
||
content = await self.load_module(name)
|
||
if content:
|
||
result[name] = content
|
||
return result
|
||
|
||
async def search_knowledge(
|
||
self,
|
||
query: str,
|
||
top_k: int = 3,
|
||
) -> List[Dict[str, Any]]:
|
||
"""
|
||
搜索相关知识
|
||
|
||
Args:
|
||
query: 搜索查询
|
||
top_k: 返回数量
|
||
|
||
Returns:
|
||
相关知识列表
|
||
"""
|
||
return await self._rag.search(query, top_k=top_k)
|
||
|
||
def build_system_prompt_with_modules(
|
||
self,
|
||
base_prompt: str,
|
||
module_names: List[str],
|
||
) -> str:
|
||
"""
|
||
构建包含知识模块的系统提示词(同步版本,使用内置知识)
|
||
|
||
Args:
|
||
base_prompt: 基础系统提示词
|
||
module_names: 要加载的模块名称列表
|
||
|
||
Returns:
|
||
增强后的系统提示词
|
||
"""
|
||
if not module_names:
|
||
return base_prompt
|
||
|
||
# 使用内置知识(同步)
|
||
knowledge_sections = []
|
||
for name in module_names:
|
||
knowledge = self._get_builtin_knowledge(name)
|
||
if knowledge:
|
||
knowledge_sections.append(f"### {knowledge['title']}\n{knowledge['content']}")
|
||
|
||
if not knowledge_sections:
|
||
return base_prompt
|
||
|
||
knowledge_text = "\n\n".join(knowledge_sections)
|
||
|
||
return f"""{base_prompt}
|
||
|
||
---
|
||
## 专业安全知识参考
|
||
|
||
以下是与当前任务相关的安全知识,请在分析时参考:
|
||
|
||
{knowledge_text}
|
||
|
||
---
|
||
"""
|
||
|
||
def _get_builtin_knowledge(self, module_name: str) -> Optional[Dict[str, Any]]:
|
||
"""获取内置知识(同步)"""
|
||
module_name_normalized = module_name.lower().replace("-", "_").replace(" ", "_")
|
||
|
||
for doc in self._rag._builtin_knowledge:
|
||
if doc.id == f"vuln_{module_name_normalized}" or doc.id == module_name_normalized:
|
||
return doc.to_dict()
|
||
|
||
# 模糊匹配
|
||
for doc in self._rag._builtin_knowledge:
|
||
if module_name_normalized in doc.id or any(
|
||
module_name_normalized in tag for tag in doc.tags
|
||
):
|
||
return doc.to_dict()
|
||
|
||
return None
|
||
|
||
def get_available_modules(self) -> List[str]:
|
||
"""获取所有可用的知识模块"""
|
||
return self._rag.get_all_vulnerability_types()
|
||
|
||
def get_all_module_names(self) -> List[str]:
|
||
"""获取所有模块名称(包括漏洞和框架)"""
|
||
vuln_types = self._rag.get_all_vulnerability_types()
|
||
frameworks = self._rag.get_all_frameworks()
|
||
return vuln_types + frameworks
|
||
|
||
def validate_modules(self, module_names: List[str]) -> Dict[str, List[str]]:
|
||
"""
|
||
验证知识模块是否存在
|
||
|
||
Args:
|
||
module_names: 要验证的模块名称列表
|
||
|
||
Returns:
|
||
{"valid": [...], "invalid": [...]}
|
||
"""
|
||
all_modules = self.get_all_module_names()
|
||
all_modules_normalized = {m.lower().replace("-", "_") for m in all_modules}
|
||
|
||
# 添加常见别名
|
||
aliases = {
|
||
"sql": "sql_injection",
|
||
"sqli": "sql_injection",
|
||
"xss": "xss_reflected",
|
||
"auth": "auth_bypass",
|
||
"idor": "idor",
|
||
"ssrf": "ssrf",
|
||
"rce": "command_injection",
|
||
"lfi": "path_traversal",
|
||
"xxe": "xxe",
|
||
}
|
||
|
||
valid = []
|
||
invalid = []
|
||
|
||
for name in module_names:
|
||
name_normalized = name.lower().replace("-", "_").replace(" ", "_")
|
||
|
||
# 检查直接匹配
|
||
if name_normalized in all_modules_normalized:
|
||
valid.append(name)
|
||
# 检查别名
|
||
elif name_normalized in aliases:
|
||
valid.append(aliases[name_normalized])
|
||
# 检查部分匹配
|
||
elif any(name_normalized in m for m in all_modules_normalized):
|
||
valid.append(name)
|
||
else:
|
||
invalid.append(name)
|
||
|
||
return {"valid": valid, "invalid": invalid}
|
||
|
||
|
||
# 全局实例
|
||
knowledge_loader = KnowledgeLoader()
|
||
|
||
|
||
# 便捷函数
|
||
def get_available_modules() -> List[str]:
|
||
"""获取所有可用的知识模块"""
|
||
return knowledge_loader.get_available_modules()
|
||
|
||
|
||
def get_module_content(module_name: str) -> Optional[str]:
|
||
"""获取模块内容(同步)"""
|
||
knowledge = knowledge_loader._get_builtin_knowledge(module_name)
|
||
return knowledge.get("content") if knowledge else None
|