CodeReview/backend/app/services/agent/config.py


"""
Agent Configuration Module
Centralized configuration management for the Agent audit framework.
All configuration values can be overridden via environment variables with AGENT_ prefix.
"""
import os
from typing import Any, Dict, List, Optional, Set
from dataclasses import dataclass, field
from enum import Enum
from functools import lru_cache
from pydantic import Field
from pydantic_settings import BaseSettings


class LogLevel(str, Enum):
"""Logging levels"""
DEBUG = "DEBUG"
INFO = "INFO"
WARNING = "WARNING"
ERROR = "ERROR"
CRITICAL = "CRITICAL"


class AgentConfig(BaseSettings):
    """
    Centralized configuration for the Agent audit framework.

    All settings can be overridden via environment variables with the AGENT_ prefix.
    Example: AGENT_LLM_MAX_RETRIES=5
    """
# ============ LLM Settings ============
llm_max_retries: int = Field(
default=3,
description="Maximum number of retries for LLM calls"
)
llm_retry_base_delay: float = Field(
default=1.0,
description="Base delay in seconds for LLM retry backoff"
)
llm_retry_max_delay: float = Field(
default=60.0,
description="Maximum delay in seconds for LLM retry backoff"
)
llm_timeout_seconds: int = Field(
default=120,
description="Timeout in seconds for LLM requests"
)
llm_max_tokens_per_call: int = Field(
default=4096,
description="Maximum tokens per LLM call"
)
llm_temperature: float = Field(
default=0.1,
description="Default temperature for LLM calls"
)
llm_stream_enabled: bool = Field(
default=True,
description="Enable streaming for LLM responses"
)
# ============ Agent Iteration Limits ============
orchestrator_max_iterations: int = Field(
default=20,
description="Maximum iterations for orchestrator agent"
)
recon_max_iterations: int = Field(
default=15,
description="Maximum iterations for reconnaissance agent"
)
analysis_max_iterations: int = Field(
default=30,
description="Maximum iterations for analysis agent"
)
verification_max_iterations: int = Field(
default=15,
description="Maximum iterations for verification agent"
)
# ============ Agent Timeouts ============
orchestrator_timeout_seconds: int = Field(
default=1800,
description="Timeout in seconds for orchestrator (30 minutes)"
)
sub_agent_timeout_seconds: int = Field(
default=600,
description="Timeout in seconds for sub-agents (10 minutes)"
)
# ============ Tool Settings ============
tool_timeout_seconds: int = Field(
default=60,
description="Default timeout for tool execution"
)
tool_max_retries: int = Field(
default=2,
description="Maximum retries for tool execution"
)
# External Tool Toggles
semgrep_enabled: bool = Field(
default=True,
description="Enable Semgrep scanner"
)
bandit_enabled: bool = Field(
default=True,
description="Enable Bandit scanner"
)
gitleaks_enabled: bool = Field(
default=True,
description="Enable Gitleaks scanner"
)
npm_audit_enabled: bool = Field(
default=True,
description="Enable npm audit"
)
safety_enabled: bool = Field(
default=True,
description="Enable Safety (Python) scanner"
)
osv_scanner_enabled: bool = Field(
default=True,
description="Enable OSV scanner"
)
# Kunlun-M (MIT License - https://github.com/LoRexxar/Kunlun-M)
kunlun_enabled: bool = Field(
default=True,
description="Enable Kunlun-M static code analyzer"
)
# External Tool Timeouts
semgrep_timeout_seconds: int = Field(
default=120,
description="Timeout for Semgrep scanner"
)
bandit_timeout_seconds: int = Field(
default=60,
description="Timeout for Bandit scanner"
)
gitleaks_timeout_seconds: int = Field(
default=60,
description="Timeout for Gitleaks scanner"
)
kunlun_timeout_seconds: int = Field(
default=600,
description="Timeout for Kunlun-M scanner (10 minutes for deep analysis)"
)
# ============ Rate Limiting ============
rate_limit_enabled: bool = Field(
default=True,
description="Enable rate limiting for tools"
)
external_tool_rate_per_second: float = Field(
default=0.2,
description="Rate limit for external tools (calls per second)"
)
external_tool_burst: int = Field(
default=3,
description="Burst limit for external tools"
)
llm_rate_per_minute: int = Field(
default=60,
description="Rate limit for LLM calls per minute"
)
# ============ Circuit Breaker ============
circuit_breaker_enabled: bool = Field(
default=True,
description="Enable circuit breaker pattern"
)
circuit_failure_threshold: int = Field(
default=5,
description="Number of failures before circuit opens"
)
circuit_recovery_timeout_seconds: float = Field(
default=30.0,
description="Time to wait before attempting recovery"
)
circuit_half_open_max_calls: int = Field(
default=3,
description="Max calls in half-open state before closing"
)
# ============ Resource Limits ============
max_file_size_bytes: int = Field(
default=10 * 1024 * 1024, # 10MB
description="Maximum file size to analyze"
)
max_files_per_scan: int = Field(
default=1000,
description="Maximum files to scan per task"
)
max_findings_per_agent: int = Field(
default=100,
description="Maximum findings per agent before stopping"
)
max_total_findings: int = Field(
default=500,
description="Maximum total findings per task"
)
max_context_messages: int = Field(
default=50,
description="Maximum messages in agent context"
)
max_tool_output_length: int = Field(
default=50000,
description="Maximum length of tool output"
)
# ============ Checkpoint & Persistence ============
checkpoint_enabled: bool = Field(
default=True,
description="Enable automatic checkpointing"
)
checkpoint_interval_iterations: int = Field(
default=5,
description="Create checkpoint every N iterations"
)
checkpoint_on_tool_complete: bool = Field(
default=False,
description="Create checkpoint after each tool completion"
)
checkpoint_on_phase_complete: bool = Field(
default=True,
description="Create checkpoint after each phase"
)
max_checkpoints_per_task: int = Field(
default=50,
description="Maximum checkpoints to keep per task"
)
# ============ Logging & Telemetry ============
log_level: LogLevel = Field(
default=LogLevel.INFO,
description="Logging level"
)
structured_logging_enabled: bool = Field(
default=True,
description="Enable structured JSON logging"
)
telemetry_enabled: bool = Field(
default=True,
description="Enable telemetry tracing"
)
log_llm_prompts: bool = Field(
default=False,
description="Log full LLM prompts (may contain sensitive data)"
)
log_llm_responses: bool = Field(
default=False,
description="Log full LLM responses"
)
log_tool_inputs: bool = Field(
default=True,
description="Log tool input parameters"
)
log_tool_outputs: bool = Field(
default=False,
description="Log full tool outputs"
)
# ============ Event Streaming ============
sse_heartbeat_interval_seconds: int = Field(
default=30,
description="SSE heartbeat interval"
)
event_queue_max_size: int = Field(
default=1000,
description="Maximum events in queue"
)
event_batch_size: int = Field(
default=10,
description="Events to batch for persistence"
)
# ============ Security ============
allowed_file_extensions: Set[str] = Field(
default={
".py", ".js", ".ts", ".jsx", ".tsx", ".java", ".go", ".rb", ".php",
".c", ".cpp", ".h", ".hpp", ".cs", ".swift", ".kt", ".rs", ".scala",
".vue", ".svelte", ".html", ".css", ".scss", ".sass", ".less",
".json", ".yaml", ".yml", ".xml", ".toml", ".ini", ".conf",
".sql", ".graphql", ".proto", ".sh", ".bash", ".zsh", ".ps1",
".md", ".txt", ".rst", ".env.example", ".gitignore",
},
description="Allowed file extensions for analysis"
)
blocked_directories: Set[str] = Field(
default={
"node_modules", "__pycache__", ".git", ".svn", ".hg",
"venv", ".venv", "env", ".env", "virtualenv",
"dist", "build", "target", "out", "bin", "obj",
".idea", ".vscode", ".vs", ".pytest_cache", ".mypy_cache",
"coverage", ".coverage", "htmlcov", ".tox", ".nox",
},
description="Directories to exclude from scanning"
)
max_path_depth: int = Field(
default=20,
description="Maximum directory depth to scan"
)
# ============ Knowledge & RAG ============
rag_enabled: bool = Field(
default=True,
description="Enable RAG-based knowledge retrieval"
)
rag_top_k: int = Field(
default=5,
description="Number of RAG results to retrieve"
)
knowledge_modules_enabled: bool = Field(
default=True,
description="Enable knowledge module injection"
)
# ============ Graceful Degradation ============
fallback_enabled: bool = Field(
default=True,
description="Enable fallback strategies on failure"
)
continue_on_tool_failure: bool = Field(
default=True,
description="Continue execution if a tool fails"
)
continue_on_partial_results: bool = Field(
default=True,
description="Continue with partial results on timeout"
)

    class Config:
env_prefix = "AGENT_"
env_file = ".env"
case_sensitive = False
extra = "ignore"
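
# Illustrative note: because env_prefix="AGENT_" and case_sensitive=False, every
# field above can be overridden through a matching environment variable (or a .env
# entry) before the configuration is loaded. Values arrive as strings and are
# parsed by pydantic-settings, e.g.:
#
#     os.environ["AGENT_LLM_MAX_RETRIES"] = "5"       # int field
#     os.environ["AGENT_SEMGREP_ENABLED"] = "false"   # bool field
#     config = AgentConfig()                          # picks up the overrides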


@dataclass
class ToolConfig:
"""Configuration for a specific tool"""
name: str
enabled: bool = True
timeout_seconds: int = 60
max_retries: int = 2
rate_limit_per_second: Optional[float] = None
fallback_tool: Optional[str] = None
circuit_breaker_enabled: bool = True


@dataclass
class AgentTypeConfig:
"""Configuration for a specific agent type"""
agent_type: str
max_iterations: int
timeout_seconds: int
tools: List[str] = field(default_factory=list)
knowledge_modules: List[str] = field(default_factory=list)


# ============ Configuration Factory ============
@lru_cache()
def get_agent_config() -> AgentConfig:
    """
    Get the singleton agent configuration instance.

    Configuration is loaded once and cached. To reload configuration,
    call get_agent_config.cache_clear() first.
    """
return AgentConfig()
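
# Usage sketch (illustrative): get_agent_config() returns a cached singleton, so an
# environment override applied after the first call only takes effect once the
# cache is cleared:
#
#     os.environ["AGENT_LLM_MAX_RETRIES"] = "5"
#     get_agent_config.cache_clear()
#     assert get_agent_config().llm_max_retries == 5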


def get_tool_config(tool_name: str) -> ToolConfig:
"""Get configuration for a specific tool"""
config = get_agent_config()
# Tool-specific configurations
tool_configs: Dict[str, ToolConfig] = {
"semgrep_scan": ToolConfig(
name="semgrep_scan",
enabled=config.semgrep_enabled,
timeout_seconds=config.semgrep_timeout_seconds,
rate_limit_per_second=config.external_tool_rate_per_second,
fallback_tool="pattern_match",
),
"bandit_scan": ToolConfig(
name="bandit_scan",
enabled=config.bandit_enabled,
timeout_seconds=config.bandit_timeout_seconds,
rate_limit_per_second=config.external_tool_rate_per_second,
fallback_tool="pattern_match",
),
"gitleaks_scan": ToolConfig(
name="gitleaks_scan",
enabled=config.gitleaks_enabled,
timeout_seconds=config.gitleaks_timeout_seconds,
rate_limit_per_second=config.external_tool_rate_per_second,
),
"npm_audit": ToolConfig(
name="npm_audit",
enabled=config.npm_audit_enabled,
timeout_seconds=config.tool_timeout_seconds,
rate_limit_per_second=config.external_tool_rate_per_second,
),
"safety_check": ToolConfig(
name="safety_check",
enabled=config.safety_enabled,
timeout_seconds=config.tool_timeout_seconds,
rate_limit_per_second=config.external_tool_rate_per_second,
),
"osv_scanner": ToolConfig(
name="osv_scanner",
enabled=config.osv_scanner_enabled,
timeout_seconds=config.tool_timeout_seconds,
rate_limit_per_second=config.external_tool_rate_per_second,
),
}
return tool_configs.get(
tool_name,
ToolConfig(
name=tool_name,
timeout_seconds=config.tool_timeout_seconds,
max_retries=config.tool_max_retries,
)
)
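
# Usage sketch (illustrative): tools with a dedicated entry get their specific
# settings; any other tool name (e.g. the hypothetical "my_custom_tool") falls back
# to the generic timeout/retry defaults:
#
#     semgrep_cfg = get_tool_config("semgrep_scan")    # semgrep_timeout_seconds, fallback "pattern_match"
#     custom_cfg = get_tool_config("my_custom_tool")   # tool_timeout_seconds / tool_max_retries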


def get_agent_type_config(agent_type: str) -> AgentTypeConfig:
"""Get configuration for a specific agent type"""
config = get_agent_config()
agent_configs = {
"orchestrator": AgentTypeConfig(
agent_type="orchestrator",
max_iterations=config.orchestrator_max_iterations,
timeout_seconds=config.orchestrator_timeout_seconds,
tools=["think", "reflect", "dispatch_agent", "finish"],
),
"recon": AgentTypeConfig(
agent_type="recon",
max_iterations=config.recon_max_iterations,
timeout_seconds=config.sub_agent_timeout_seconds,
tools=["list_files", "read_file", "search_code"],
knowledge_modules=["project_analysis"],
),
"analysis": AgentTypeConfig(
agent_type="analysis",
max_iterations=config.analysis_max_iterations,
timeout_seconds=config.sub_agent_timeout_seconds,
tools=[
"smart_scan", "pattern_match", "dataflow_analysis",
"read_file", "search_code", "semgrep_scan", "bandit_scan"
],
knowledge_modules=["sql_injection", "xss", "command_injection"],
),
"verification": AgentTypeConfig(
agent_type="verification",
max_iterations=config.verification_max_iterations,
timeout_seconds=config.sub_agent_timeout_seconds,
tools=["validate_vulnerability", "dataflow_analysis", "sandbox_execute"],
knowledge_modules=["vulnerability_verification"],
),
}
return agent_configs.get(
agent_type,
AgentTypeConfig(
agent_type=agent_type,
max_iterations=config.analysis_max_iterations,
timeout_seconds=config.sub_agent_timeout_seconds,
)
)
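
# Usage sketch (illustrative): known agent types get their dedicated tool lists and
# iteration budgets; an unknown type (e.g. the hypothetical "custom_agent") falls
# back to the analysis iteration limit and the generic sub-agent timeout:
#
#     recon_cfg = get_agent_type_config("recon")         # recon tools, recon_max_iterations
#     other_cfg = get_agent_type_config("custom_agent")  # analysis_max_iterations, sub_agent_timeout_seconds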


# ============ Configuration Validation ============
def validate_config() -> List[str]:
    """
    Validate the current configuration and return a list of warnings.

    Returns:
        List of warning messages for potentially problematic settings.
    """
config = get_agent_config()
warnings = []
# Check for potentially problematic settings
if config.llm_max_retries > 5:
warnings.append(
f"llm_max_retries={config.llm_max_retries} is high, "
"may cause long delays on persistent failures"
)
if config.orchestrator_max_iterations > 50:
warnings.append(
f"orchestrator_max_iterations={config.orchestrator_max_iterations} "
"is very high, may lead to excessive LLM costs"
)
if config.max_file_size_bytes > 50 * 1024 * 1024:
warnings.append(
f"max_file_size_bytes={config.max_file_size_bytes} is very large, "
"may cause memory issues"
)
if not config.circuit_breaker_enabled:
warnings.append(
"circuit_breaker_enabled=False may cause cascading failures"
)
if not config.checkpoint_enabled:
warnings.append(
"checkpoint_enabled=False means state cannot be recovered on failure"
)
if config.log_llm_prompts or config.log_llm_responses:
warnings.append(
"LLM prompt/response logging enabled - may log sensitive data"
)
return warnings
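
# Usage sketch (illustrative): surface configuration warnings at application
# startup instead of failing hard, e.g.:
#
#     import logging
#     for warning in validate_config():
#         logging.getLogger(__name__).warning("Agent config: %s", warning)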


# ============ Environment-Specific Presets ============
def apply_development_preset():
    """Apply development-friendly defaults; existing AGENT_* environment variables are not overridden."""
os.environ.setdefault("AGENT_LOG_LEVEL", "DEBUG")
os.environ.setdefault("AGENT_LOG_LLM_PROMPTS", "true")
os.environ.setdefault("AGENT_LOG_TOOL_OUTPUTS", "true")
os.environ.setdefault("AGENT_CHECKPOINT_INTERVAL_ITERATIONS", "1")
get_agent_config.cache_clear()


def apply_production_preset():
    """Apply production-safe defaults; existing AGENT_* environment variables are not overridden."""
os.environ.setdefault("AGENT_LOG_LEVEL", "INFO")
os.environ.setdefault("AGENT_LOG_LLM_PROMPTS", "false")
os.environ.setdefault("AGENT_LOG_LLM_RESPONSES", "false")
os.environ.setdefault("AGENT_CIRCUIT_BREAKER_ENABLED", "true")
os.environ.setdefault("AGENT_CHECKPOINT_ENABLED", "true")
get_agent_config.cache_clear()


def apply_testing_preset():
    """Apply testing defaults with shorter timeouts; existing AGENT_* environment variables are not overridden."""
os.environ.setdefault("AGENT_LLM_TIMEOUT_SECONDS", "30")
os.environ.setdefault("AGENT_TOOL_TIMEOUT_SECONDS", "10")
os.environ.setdefault("AGENT_ORCHESTRATOR_MAX_ITERATIONS", "5")
os.environ.setdefault("AGENT_ANALYSIS_MAX_ITERATIONS", "10")
os.environ.setdefault("AGENT_CIRCUIT_BREAKER_ENABLED", "false")
get_agent_config.cache_clear()
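

# Minimal self-check sketch (illustrative, not used by the service at runtime):
# apply the development preset, load the cached config, and print any validation
# warnings. Assumes the module is executed directly, e.g.
# `python -m app.services.agent.config`.
if __name__ == "__main__":
    apply_development_preset()
    cfg = get_agent_config()
    print(f"log_level={cfg.log_level.value}, llm_max_retries={cfg.llm_max_retries}")
    for warning in validate_config():
        print(f"WARNING: {warning}")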