CodeReview/backend/app/services/agent/knowledge/loader.py

208 lines
5.9 KiB
Python
Raw Normal View History

"""
知识加载器 - 基于RAG的知识模块加载
将安全知识集成到Agent的系统提示词中
"""
import logging
from typing import List, Dict, Any, Optional
from .base import KnowledgeCategory
logger = logging.getLogger(__name__)
class KnowledgeLoader:
"""
知识加载器
负责将RAG检索的知识集成到Agent系统提示词中
"""
def __init__(self, rag=None):
# 延迟导入避免循环依赖
if rag is None:
from .rag_knowledge import security_knowledge_rag
rag = security_knowledge_rag
self._rag = rag
async def load_module(self, module_name: str) -> str:
"""
加载单个知识模块
Args:
module_name: 模块名称如sql_injection, xss等
Returns:
模块内容
"""
knowledge = await self._rag.get_vulnerability_knowledge(module_name)
if knowledge:
return knowledge.get("content", "")
return ""
async def load_modules(self, module_names: List[str]) -> Dict[str, str]:
"""
批量加载知识模块
Args:
module_names: 模块名称列表
Returns:
模块名称到内容的映射
"""
result = {}
for name in module_names:
content = await self.load_module(name)
if content:
result[name] = content
return result
async def search_knowledge(
self,
query: str,
top_k: int = 3,
) -> List[Dict[str, Any]]:
"""
搜索相关知识
Args:
query: 搜索查询
top_k: 返回数量
Returns:
相关知识列表
"""
return await self._rag.search(query, top_k=top_k)
def build_system_prompt_with_modules(
self,
base_prompt: str,
module_names: List[str],
) -> str:
"""
构建包含知识模块的系统提示词同步版本使用内置知识
Args:
base_prompt: 基础系统提示词
module_names: 要加载的模块名称列表
Returns:
增强后的系统提示词
"""
if not module_names:
return base_prompt
# 使用内置知识(同步)
knowledge_sections = []
for name in module_names:
knowledge = self._get_builtin_knowledge(name)
if knowledge:
knowledge_sections.append(f"### {knowledge['title']}\n{knowledge['content']}")
if not knowledge_sections:
return base_prompt
knowledge_text = "\n\n".join(knowledge_sections)
return f"""{base_prompt}
---
## 专业安全知识参考
以下是与当前任务相关的安全知识请在分析时参考
{knowledge_text}
---
"""
def _get_builtin_knowledge(self, module_name: str) -> Optional[Dict[str, Any]]:
"""获取内置知识(同步)"""
module_name_normalized = module_name.lower().replace("-", "_").replace(" ", "_")
for doc in self._rag._builtin_knowledge:
if doc.id == f"vuln_{module_name_normalized}" or doc.id == module_name_normalized:
return doc.to_dict()
# 模糊匹配
for doc in self._rag._builtin_knowledge:
if module_name_normalized in doc.id or any(
module_name_normalized in tag for tag in doc.tags
):
return doc.to_dict()
return None
def get_available_modules(self) -> List[str]:
"""获取所有可用的知识模块"""
return self._rag.get_all_vulnerability_types()
def get_all_module_names(self) -> List[str]:
"""获取所有模块名称(包括漏洞和框架)"""
vuln_types = self._rag.get_all_vulnerability_types()
frameworks = self._rag.get_all_frameworks()
return vuln_types + frameworks
def validate_modules(self, module_names: List[str]) -> Dict[str, List[str]]:
"""
验证知识模块是否存在
Args:
module_names: 要验证的模块名称列表
Returns:
{"valid": [...], "invalid": [...]}
"""
all_modules = self.get_all_module_names()
all_modules_normalized = {m.lower().replace("-", "_") for m in all_modules}
# 添加常见别名
aliases = {
"sql": "sql_injection",
"sqli": "sql_injection",
"xss": "xss_reflected",
"auth": "auth_bypass",
"idor": "idor",
"ssrf": "ssrf",
"rce": "command_injection",
"lfi": "path_traversal",
"xxe": "xxe",
}
valid = []
invalid = []
for name in module_names:
name_normalized = name.lower().replace("-", "_").replace(" ", "_")
# 检查直接匹配
if name_normalized in all_modules_normalized:
valid.append(name)
# 检查别名
elif name_normalized in aliases:
valid.append(aliases[name_normalized])
# 检查部分匹配
elif any(name_normalized in m for m in all_modules_normalized):
valid.append(name)
else:
invalid.append(name)
return {"valid": valid, "invalid": invalid}
# 全局实例
knowledge_loader = KnowledgeLoader()
# 便捷函数
def get_available_modules() -> List[str]:
"""获取所有可用的知识模块"""
return knowledge_loader.get_available_modules()
def get_module_content(module_name: str) -> Optional[str]:
"""获取模块内容(同步)"""
knowledge = knowledge_loader._get_builtin_knowledge(module_name)
return knowledge.get("content") if knowledge else None