CodeReview/backend/app/services/agent/knowledge/loader.py

208 lines
5.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
知识加载器 - 基于RAG的知识模块加载
将安全知识集成到Agent的系统提示词中
"""
import logging
from typing import List, Dict, Any, Optional
from .base import KnowledgeCategory
logger = logging.getLogger(__name__)
class KnowledgeLoader:
    """Load security knowledge from a RAG backend and merge it into prompts.

    The loader wraps a RAG object and offers three kinds of access:
    asynchronous lookups that are delegated to the RAG object, a synchronous
    prompt builder that reads the RAG object's ``_builtin_knowledge``
    documents, and validation of requested module names.
    """

    # Short names accepted by ``validate_modules`` and the module each one
    # resolves to.
    _ALIASES: Dict[str, str] = {
        "sql": "sql_injection",
        "sqli": "sql_injection",
        "xss": "xss_reflected",
        "auth": "auth_bypass",
        "idor": "idor",
        "ssrf": "ssrf",
        "rce": "command_injection",
        "lfi": "path_traversal",
        "xxe": "xxe",
    }

    def __init__(self, rag=None):
        """Create a loader.

        Args:
            rag: Knowledge backend to query. When ``None``, the shared
                ``security_knowledge_rag`` object from ``.rag_knowledge``
                is used.
        """
        if rag is None:
            # Imported lazily to avoid a circular dependency.
            from .rag_knowledge import security_knowledge_rag

            rag = security_knowledge_rag
        self._rag = rag

    @staticmethod
    def _normalize_name(name: str) -> str:
        """Lower-case *name* and turn hyphens and spaces into underscores."""
        return name.lower().replace("-", "_").replace(" ", "_")

    async def load_module(self, module_name: str) -> str:
        """Load the content of a single knowledge module.

        Args:
            module_name: Module name, e.g. ``sql_injection`` or ``xss``.

        Returns:
            The module's ``content`` value, or ``""`` when the backend
            returns nothing for the name or the content is missing/empty.
        """
        knowledge = await self._rag.get_vulnerability_knowledge(module_name)
        if not knowledge:
            return ""
        # ``or ""`` keeps the declared ``str`` return type even if the
        # stored content is ``None``.
        return knowledge.get("content", "") or ""

    async def load_modules(self, module_names: List[str]) -> Dict[str, str]:
        """Load several knowledge modules one after another.

        Args:
            module_names: Names of the modules to load.

        Returns:
            Mapping of module name to content. Names whose content is empty
            are left out.
        """
        loaded: Dict[str, str] = {}
        for name in module_names:
            content = await self.load_module(name)
            if content:
                loaded[name] = content
        return loaded

    async def search_knowledge(
        self,
        query: str,
        top_k: int = 3,
    ) -> List[Dict[str, Any]]:
        """Search the backend for knowledge related to *query*.

        Args:
            query: Search query.
            top_k: Number of results to request from the backend.

        Returns:
            Whatever the backend's ``search`` call returns.
        """
        return await self._rag.search(query, top_k=top_k)

    def build_system_prompt_with_modules(
        self,
        base_prompt: str,
        module_names: List[str],
    ) -> str:
        """Append built-in knowledge sections to a system prompt (synchronous).

        Args:
            base_prompt: The base system prompt.
            module_names: Names of the knowledge modules to include.

        Returns:
            ``base_prompt`` unchanged when no module name resolves to a
            built-in document; otherwise the prompt followed by a reference
            block with one ``###`` section per distinct document.
        """
        if not module_names:
            return base_prompt

        sections: List[str] = []
        seen = set()
        for name in module_names:
            knowledge = self._get_builtin_knowledge(name)
            if not knowledge:
                continue
            section = f"### {knowledge['title']}\n{knowledge['content']}"
            # Different names (e.g. an alias tag and the full name) can
            # resolve to the same document; include each section only once.
            if section in seen:
                continue
            seen.add(section)
            sections.append(section)

        if not sections:
            return base_prompt

        knowledge_text = "\n\n".join(sections)
        # NOTE(review): the header text below is runtime output and is
        # reproduced exactly as seen in the reviewed source; confirm
        # punctuation and line breaks against the repository copy.
        return (
            f"{base_prompt}\n"
            "---\n"
            "## 专业安全知识参考\n"
            "以下是与当前任务相关的安全知识,请在分析时参考:\n"
            f"{knowledge_text}\n"
            "---\n"
        )

    def _get_builtin_knowledge(self, module_name: str) -> Optional[Dict[str, Any]]:
        """Look up a built-in knowledge document by module name (synchronous).

        An exact id match (``vuln_<name>`` or ``<name>``) wins over a fuzzy
        match, where the normalized name is a substring of a document id or
        of one of its tags.

        Args:
            module_name: Module name; case, hyphens and spaces are normalized.

        Returns:
            The matching document's ``to_dict()`` result, or ``None`` when
            nothing matches or the name is blank.
        """
        # A blank name must not match: the empty string is a substring of
        # every id, so the fuzzy pass would return the first document.
        if not module_name or not module_name.strip():
            return None

        wanted = self._normalize_name(module_name)
        documents = self._rag._builtin_knowledge

        exact_ids = (f"vuln_{wanted}", wanted)
        for doc in documents:
            if doc.id in exact_ids:
                return doc.to_dict()

        # Fuzzy match on id or tags.
        for doc in documents:
            if wanted in doc.id or any(wanted in tag for tag in doc.tags):
                return doc.to_dict()
        return None

    def get_available_modules(self) -> List[str]:
        """Return the vulnerability types reported by the backend."""
        return self._rag.get_all_vulnerability_types()

    def get_all_module_names(self) -> List[str]:
        """Return all module names: vulnerability types, then frameworks."""
        vuln_types = self._rag.get_all_vulnerability_types()
        frameworks = self._rag.get_all_frameworks()
        return [*vuln_types, *frameworks]

    def validate_modules(self, module_names: List[str]) -> Dict[str, List[str]]:
        """Check which requested module names are known.

        A name is accepted when, after normalization, it equals a known
        module name, is a known alias, or is a substring of a known module
        name. Alias hits are reported under the name they resolve to; other
        hits are reported as given.

        Args:
            module_names: Names to validate.

        Returns:
            ``{"valid": [...], "invalid": [...]}``.
        """
        known = {self._normalize_name(m) for m in self.get_all_module_names()}

        valid: List[str] = []
        invalid: List[str] = []
        for name in module_names:
            # Blank names are rejected up front; otherwise the substring
            # check below would accept them ("" is in every string).
            if not name or not name.strip():
                invalid.append(name)
                continue

            key = self._normalize_name(name)
            if key in known:
                valid.append(name)
            elif key in self._ALIASES:
                valid.append(self._ALIASES[key])
            elif any(key in candidate for candidate in known):
                valid.append(name)
            else:
                invalid.append(name)
        return {"valid": valid, "invalid": invalid}
# Module-level shared loader instance. It is created at import time with no
# ``rag`` argument, so the constructor falls back to ``security_knowledge_rag``.
knowledge_loader = KnowledgeLoader()
# Convenience wrappers around the shared ``knowledge_loader`` instance.
def get_available_modules() -> List[str]:
    """Return the available knowledge modules reported by the shared loader."""
    modules = knowledge_loader.get_available_modules()
    return modules
def get_module_content(module_name: str) -> Optional[str]:
    """Return a built-in module's content via the shared loader (synchronous).

    Args:
        module_name: Name of the knowledge module to look up.

    Returns:
        The document's ``content`` value, or ``None`` when the shared
        loader finds no built-in document for the name.
    """
    document = knowledge_loader._get_builtin_knowledge(module_name)
    if not document:
        return None
    return document.get("content")