# CodeReview/backend/app/services/agent/tools/code_analysis_tool.py

"""
代码分析工具
使用 LLM 深度分析代码安全问题
"""
import json
from typing import Optional, List, Dict, Any
from pydantic import BaseModel, Field
from .base import AgentTool, ToolResult
class CodeAnalysisInput(BaseModel):
"""代码分析输入"""
code: str = Field(description="要分析的代码内容")
file_path: str = Field(default="unknown", description="文件路径")
language: str = Field(default="python", description="编程语言")
focus: Optional[str] = Field(
default=None,
description="重点关注的漏洞类型,如 sql_injection, xss, command_injection"
)
context: Optional[str] = Field(
default=None,
description="额外的上下文信息,如相关的其他代码片段"
)


class CodeAnalysisTool(AgentTool):
    """
    Code analysis tool.

    Uses an LLM to perform an in-depth security analysis of code.
    """

    def __init__(self, llm_service):
        """
        Initialize the code analysis tool.

        Args:
            llm_service: LLM service instance
        """
        super().__init__()
        self.llm_service = llm_service

    @property
    def name(self) -> str:
        return "code_analysis"

    @property
    def description(self) -> str:
        return """Perform an in-depth security analysis of code.

Uses an LLM to run a full security audit and identify potential vulnerabilities.

When to use:
- Deep analysis of code that looks suspicious
- Analysis of complex business-logic vulnerabilities
- Tracing data flow and taint propagation
- Producing detailed vulnerability reports and fix suggestions

Inputs:
- code: the code to analyze
- file_path: file path
- language: programming language
- focus: optional, vulnerability types to focus on
- context: optional, additional context code

This tool consumes a relatively large number of tokens; use it only after a suspected issue has been identified."""

    @property
    def args_schema(self):
        return CodeAnalysisInput
    async def _execute(
        self,
        code: str,
        file_path: str = "unknown",
        language: str = "python",
        focus: Optional[str] = None,
        context: Optional[str] = None,
        **kwargs
    ) -> ToolResult:
        """Run the code analysis."""
        try:
            # Run the LLM-based analysis
            analysis = await self.llm_service.analyze_code(code, language)
            issues = analysis.get("issues", [])

            if not issues:
                return ToolResult(
                    success=True,
                    data="Code analysis complete; no obvious security issues found.\n\n"
                         f"Quality score: {analysis.get('quality_score', 'N/A')}\n"
                         f"File: {file_path}",
                    metadata={
                        "file_path": file_path,
                        "issues_count": 0,
                        "quality_score": analysis.get("quality_score"),
                    }
                )

            # Format the output
            output_parts = [f"🔍 Code analysis result - {file_path}\n"]
            output_parts.append(f"Found {len(issues)} issue(s):\n")

            for i, issue in enumerate(issues):
                severity_icon = {
                    "critical": "🔴",
                    "high": "🟠",
                    "medium": "🟡",
                    "low": "🟢"
                }.get(issue.get("severity", ""), "")

                output_parts.append(f"\n{severity_icon} Issue {i + 1}: {issue.get('title', 'Unknown')}")
                output_parts.append(f"  Type: {issue.get('type', 'unknown')}")
                output_parts.append(f"  Severity: {issue.get('severity', 'unknown')}")
                output_parts.append(f"  Line: {issue.get('line', 'N/A')}")
                output_parts.append(f"  Description: {issue.get('description', '')}")

                if issue.get("code_snippet"):
                    output_parts.append(f"  Code snippet:\n  ```\n  {issue.get('code_snippet')}\n  ```")
                if issue.get("suggestion"):
                    output_parts.append(f"  Suggested fix: {issue.get('suggestion')}")
                if issue.get("ai_explanation"):
                    output_parts.append(f"  AI explanation: {issue.get('ai_explanation')}")

            output_parts.append(f"\nQuality score: {analysis.get('quality_score', 'N/A')}/100")

            return ToolResult(
                success=True,
                data="\n".join(output_parts),
                metadata={
                    "file_path": file_path,
                    "issues_count": len(issues),
                    "quality_score": analysis.get("quality_score"),
                    "issues": issues,
                }
            )
        except Exception as e:
            return ToolResult(
                success=False,
                error=f"Code analysis failed: {str(e)}",
            )
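

# --- Usage sketch (illustrative only, hypothetical stub) -------------------
# Shows how CodeAnalysisTool might be driven end to end. _StubLLMService and
# the direct _execute call are assumptions for the sketch: the real
# llm_service must expose an async analyze_code(code, language) that returns
# a dict with "issues" and "quality_score", as _execute above expects, and an
# agent runtime would normally call the tool through AgentTool's public API.
# Run with e.g. asyncio.run(_demo_code_analysis()) from a script or test.
async def _demo_code_analysis() -> None:
    class _StubLLMService:
        async def analyze_code(self, code: str, language: str) -> Dict[str, Any]:
            # Canned response mimicking the expected LLM output shape.
            return {
                "issues": [{
                    "title": "Possible SQL injection",
                    "type": "sql_injection",
                    "severity": "high",
                    "line": 1,
                    "description": "User input is concatenated into a SQL query.",
                    "suggestion": "Use parameterized queries.",
                }],
                "quality_score": 55,
            }

    tool = CodeAnalysisTool(llm_service=_StubLLMService())
    result = await tool._execute(
        code="query = 'SELECT * FROM users WHERE id = ' + user_id",
        file_path="app/db.py",
        language="python",
    )
    print(result.data)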


class DataFlowAnalysisInput(BaseModel):
    """Input schema for data flow analysis."""
    source_code: str = Field(description="Code containing the data source")
    sink_code: Optional[str] = Field(default=None, description="Code containing the data sink (e.g. a dangerous function)")
    variable_name: str = Field(description="Name of the variable to trace")
    file_path: str = Field(default="unknown", description="File path")


class DataFlowAnalysisTool(AgentTool):
    """
    Data flow analysis tool.

    Traces the flow of a variable from source to sink.
    """

    def __init__(self, llm_service):
        super().__init__()
        self.llm_service = llm_service

    @property
    def name(self) -> str:
        return "dataflow_analysis"

    @property
    def description(self) -> str:
        return """Analyze data flow in code, tracing a variable from its source (e.g. user input) to a sink (e.g. a dangerous function).

When to use:
- Trace how user input reaches dangerous functions
- Check whether a variable is sanitized along the way
- Identify taint propagation paths

Inputs:
- source_code: code containing the data source
- sink_code: code containing the data sink (optional)
- variable_name: name of the variable to trace
- file_path: file path"""

    @property
    def args_schema(self):
        return DataFlowAnalysisInput
    async def _execute(
        self,
        source_code: str,
        variable_name: str,
        sink_code: Optional[str] = None,
        file_path: str = "unknown",
        **kwargs
    ) -> ToolResult:
        """Run the data flow analysis."""
        try:
            # Build the analysis prompt
            analysis_prompt = f"""Analyze the data flow of variable '{variable_name}' in the following code.

Source code:
```
{source_code}
```
"""
            if sink_code:
                analysis_prompt += f"""
Sink code (possibly a dangerous function):
```
{sink_code}
```
"""
            analysis_prompt += f"""
Please analyze:
1. Where does the value of '{variable_name}' come from? (user input, configuration, database, etc.)
2. Is the variable sanitized/validated as it is passed along?
3. Which dangerous functions does it ultimately flow into?
4. Is there a security risk?

Return the result as JSON with:
- source_type: type of the data source
- sanitized: whether the value is sanitized
- sanitization_methods: sanitization methods used
- dangerous_sinks: list of dangerous sinks it flows into
- risk_level: risk level (high/medium/low/none)
- explanation: detailed explanation
- recommendation: recommendation
"""

            # Run the analysis via analyze_code_with_custom_prompt
            result = await self.llm_service.analyze_code_with_custom_prompt(
                code=source_code,
                language="text",
                custom_prompt=analysis_prompt,
            )

            # Format the output
            output_parts = [f"📊 Data flow analysis result - variable: {variable_name}\n"]

            if isinstance(result, dict):
                if result.get("source_type"):
                    output_parts.append(f"Data source: {result.get('source_type')}")
                if result.get("sanitized") is not None:
                    sanitized = "✅ yes" if result.get("sanitized") else "❌ no"
                    output_parts.append(f"Sanitized: {sanitized}")
                if result.get("sanitization_methods"):
                    output_parts.append(f"Sanitization methods: {', '.join(result.get('sanitization_methods', []))}")
                if result.get("dangerous_sinks"):
                    output_parts.append(f"Dangerous sinks: {', '.join(result.get('dangerous_sinks', []))}")
                if result.get("risk_level"):
                    risk_icons = {"high": "🔴", "medium": "🟠", "low": "🟡", "none": "🟢"}
                    icon = risk_icons.get(result.get("risk_level", ""), "")
                    output_parts.append(f"Risk level: {icon} {result.get('risk_level', '').upper()}")
                if result.get("explanation"):
                    output_parts.append(f"\nAnalysis: {result.get('explanation')}")
                if result.get("recommendation"):
                    output_parts.append(f"\nRecommendation: {result.get('recommendation')}")
            else:
                output_parts.append(str(result))

            return ToolResult(
                success=True,
                data="\n".join(output_parts),
                metadata={
                    "variable": variable_name,
                    "file_path": file_path,
                    "analysis": result,
                }
            )
        except Exception as e:
            return ToolResult(
                success=False,
                error=f"Data flow analysis failed: {str(e)}",
            )
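

# --- Usage sketch (illustrative only, hypothetical stub) -------------------
# DataFlowAnalysisTool delegates to llm_service.analyze_code_with_custom_prompt
# and formats whatever dict comes back. The stub below returns a canned dict
# in the shape the formatter above looks for; both the stub and the direct
# _execute call are assumptions for the sketch, not the real service contract.
async def _demo_dataflow_analysis() -> None:
    class _StubLLMService:
        async def analyze_code_with_custom_prompt(self, code: str, language: str, custom_prompt: str) -> Dict[str, Any]:
            return {
                "source_type": "user input (HTTP request parameter)",
                "sanitized": False,
                "dangerous_sinks": ["os.system"],
                "risk_level": "high",
                "explanation": "The parameter reaches os.system without validation.",
                "recommendation": "Use subprocess with an argument list and validate the input.",
            }

    tool = DataFlowAnalysisTool(llm_service=_StubLLMService())
    result = await tool._execute(
        source_code="cmd = request.args.get('cmd')\nos.system(cmd)",
        variable_name="cmd",
        file_path="app/views.py",
    )
    print(result.data)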


class VulnerabilityValidationInput(BaseModel):
    """Input schema for vulnerability validation."""
    code: str = Field(description="Code that may contain the vulnerability")
    vulnerability_type: str = Field(description="Vulnerability type")
    file_path: str = Field(default="unknown", description="File path")
    line_number: Optional[int] = Field(default=None, description="Line number")
    context: Optional[str] = Field(default=None, description="Additional context")


class VulnerabilityValidationTool(AgentTool):
    """
    Vulnerability validation tool.

    Checks whether a suspected vulnerability actually exists.
    """

    def __init__(self, llm_service):
        super().__init__()
        self.llm_service = llm_service

    @property
    def name(self) -> str:
        return "vulnerability_validation"

    @property
    def description(self) -> str:
        return """Validate whether a suspected vulnerability actually exists.

Performs an in-depth analysis of a potential finding to decide whether it is a real security issue.

Inputs:
- code: code containing the suspected vulnerability
- vulnerability_type: vulnerability type (e.g. sql_injection, xss)
- file_path: file path
- line_number: optional, line number
- context: optional, additional context code

Outputs:
- Verdict (confirmed / likely / false positive)
- Detailed analysis
- Exploitation conditions
- PoC idea (if the vulnerability is confirmed)"""

    @property
    def args_schema(self):
        return VulnerabilityValidationInput
    async def _execute(
        self,
        code: str,
        vulnerability_type: str,
        file_path: str = "unknown",
        line_number: Optional[int] = None,
        context: Optional[str] = None,
        **kwargs
    ) -> ToolResult:
        """Run the vulnerability validation."""
        try:
            validation_prompt = f"""You are a professional security researcher. Verify whether the following code really contains a {vulnerability_type} vulnerability.

Code:
```
{code}
```
{'Additional context:' + chr(10) + '```' + chr(10) + context + chr(10) + '```' if context else ''}

Please analyze:
1. Does this code really contain a {vulnerability_type} vulnerability?
2. What conditions are required to exploit it?
3. How would an attacker exploit it?
4. Could this be a false positive? Why?

Return JSON in the following format:
{{
    "is_vulnerable": true/false/null (null means undetermined),
    "confidence": 0.0-1.0,
    "verdict": "confirmed/likely/unlikely/false_positive",
    "exploitation_conditions": ["condition 1", "condition 2"],
    "attack_vector": "description of the attack vector",
    "poc_idea": "PoC idea, if the vulnerability exists",
    "false_positive_reason": "if this is a false positive, explain why",
    "detailed_analysis": "detailed analysis"
}}
"""

            result = await self.llm_service.analyze_code_with_custom_prompt(
                code=code,
                language="text",
                custom_prompt=validation_prompt,
            )

            # Format the output
            output_parts = [f"🔎 Vulnerability validation result - {vulnerability_type}\n"]
            output_parts.append(f"File: {file_path}")
            if line_number:
                output_parts.append(f"Line: {line_number}")
            output_parts.append("")

            if isinstance(result, dict):
                # Verdict
                verdict_icons = {
                    "confirmed": "🔴 Confirmed vulnerability",
                    "likely": "🟠 Likely vulnerable",
                    "unlikely": "🟡 Likely a false positive",
                    "false_positive": "🟢 False positive",
                }
                verdict = result.get("verdict", "unknown")
                output_parts.append(f"Verdict: {verdict_icons.get(verdict, verdict)}")

                if result.get("confidence"):
                    output_parts.append(f"Confidence: {result.get('confidence') * 100:.0f}%")
                if result.get("exploitation_conditions"):
                    output_parts.append("\nExploitation conditions:")
                    for cond in result.get("exploitation_conditions", []):
                        output_parts.append(f"  - {cond}")
                if result.get("attack_vector"):
                    output_parts.append(f"\nAttack vector: {result.get('attack_vector')}")
                if result.get("poc_idea") and verdict in ["confirmed", "likely"]:
                    output_parts.append(f"\nPoC idea: {result.get('poc_idea')}")
                if result.get("false_positive_reason") and verdict in ["unlikely", "false_positive"]:
                    output_parts.append(f"\nFalse positive reason: {result.get('false_positive_reason')}")
                if result.get("detailed_analysis"):
                    output_parts.append(f"\nDetailed analysis:\n{result.get('detailed_analysis')}")
            else:
                output_parts.append(str(result))

            return ToolResult(
                success=True,
                data="\n".join(output_parts),
                metadata={
                    "vulnerability_type": vulnerability_type,
                    "file_path": file_path,
                    "line_number": line_number,
                    "validation": result,
                }
            )
        except Exception as e:
            return ToolResult(
                success=False,
                error=f"Vulnerability validation failed: {str(e)}",
            )
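

# --- Usage sketch (illustrative only, hypothetical stub) -------------------
# VulnerabilityValidationTool also goes through analyze_code_with_custom_prompt;
# the stub returns a dict shaped like the JSON the validation prompt above asks
# for. The stub service and the direct _execute call are assumptions for the
# sketch; an agent runtime would normally invoke the tool via AgentTool's
# public entry point.
async def _demo_vulnerability_validation() -> None:
    class _StubLLMService:
        async def analyze_code_with_custom_prompt(self, code: str, language: str, custom_prompt: str) -> Dict[str, Any]:
            return {
                "is_vulnerable": True,
                "confidence": 0.9,
                "verdict": "confirmed",
                "exploitation_conditions": ["attacker controls the 'user_id' value"],
                "attack_vector": "Inject ' OR '1'='1 through the user_id parameter.",
                "poc_idea": "Send user_id=1' OR '1'='1 and observe that all rows are returned.",
                "detailed_analysis": "The value is interpolated directly into the SQL string.",
            }

    tool = VulnerabilityValidationTool(llm_service=_StubLLMService())
    result = await tool._execute(
        code="cursor.execute(\"SELECT * FROM users WHERE id = '%s'\" % user_id)",
        vulnerability_type="sql_injection",
        file_path="app/db.py",
        line_number=42,
    )
    print(result.data)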