CodeReview/backend/app/api/v1/endpoints/prompts.py

379 lines
12 KiB
Python

"""
提示词模板 API 端点
"""
import json
import time
from typing import Any, List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy import func as sql_func
from app.api import deps
from app.db.session import get_db
from app.models.prompt_template import PromptTemplate
from app.models.user import User
from app.schemas.prompt_template import (
PromptTemplateCreate,
PromptTemplateUpdate,
PromptTemplateResponse,
PromptTemplateListResponse,
PromptTestRequest,
PromptTestResponse,
)
router = APIRouter()
@router.get("", response_model=PromptTemplateListResponse)
async def list_prompt_templates(
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=100),
template_type: Optional[str] = Query(None, description="模板类型过滤"),
is_active: Optional[bool] = Query(None, description="是否启用"),
db: AsyncSession = Depends(get_db),
current_user: User = Depends(deps.get_current_user),
) -> Any:
"""获取提示词模板列表"""
query = select(PromptTemplate)
# 过滤条件:系统模板 + 当前用户创建的模板
query = query.where(
(PromptTemplate.is_system == True) |
(PromptTemplate.created_by == current_user.id)
)
if template_type:
query = query.where(PromptTemplate.template_type == template_type)
if is_active is not None:
query = query.where(PromptTemplate.is_active == is_active)
# 排序:系统模板优先,然后按排序权重和创建时间
query = query.order_by(
PromptTemplate.is_system.desc(),
PromptTemplate.is_default.desc(),
PromptTemplate.sort_order.asc(),
PromptTemplate.created_at.desc()
)
# 计数
count_query = select(sql_func.count()).select_from(query.subquery())
total = (await db.execute(count_query)).scalar()
# 分页
query = query.offset(skip).limit(limit)
result = await db.execute(query)
templates = result.scalars().all()
items = []
for t in templates:
variables = {}
if t.variables:
try:
variables = json.loads(t.variables)
except:
pass
items.append(PromptTemplateResponse(
id=t.id,
name=t.name,
description=t.description,
template_type=t.template_type,
content_zh=t.content_zh,
content_en=t.content_en,
variables=variables,
is_default=t.is_default,
is_system=t.is_system,
is_active=t.is_active,
sort_order=t.sort_order,
created_by=t.created_by,
created_at=t.created_at,
updated_at=t.updated_at,
))
return PromptTemplateListResponse(items=items, total=total)
@router.get("/{template_id}", response_model=PromptTemplateResponse)
async def get_prompt_template(
template_id: str,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(deps.get_current_user),
) -> Any:
"""获取单个提示词模板"""
result = await db.execute(
select(PromptTemplate).where(PromptTemplate.id == template_id)
)
template = result.scalar_one_or_none()
if not template:
raise HTTPException(status_code=404, detail="模板不存在")
# 检查权限
if not template.is_system and template.created_by != current_user.id:
raise HTTPException(status_code=403, detail="无权访问此模板")
variables = {}
if template.variables:
try:
variables = json.loads(template.variables)
except:
pass
return PromptTemplateResponse(
id=template.id,
name=template.name,
description=template.description,
template_type=template.template_type,
content_zh=template.content_zh,
content_en=template.content_en,
variables=variables,
is_default=template.is_default,
is_system=template.is_system,
is_active=template.is_active,
sort_order=template.sort_order,
created_by=template.created_by,
created_at=template.created_at,
updated_at=template.updated_at,
)
@router.post("", response_model=PromptTemplateResponse)
async def create_prompt_template(
template_in: PromptTemplateCreate,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(deps.get_current_user),
) -> Any:
"""创建提示词模板"""
template = PromptTemplate(
name=template_in.name,
description=template_in.description,
template_type=template_in.template_type,
content_zh=template_in.content_zh,
content_en=template_in.content_en,
variables=json.dumps(template_in.variables or {}),
is_active=template_in.is_active,
sort_order=template_in.sort_order,
is_system=False,
is_default=False,
created_by=current_user.id,
)
db.add(template)
await db.commit()
await db.refresh(template)
return PromptTemplateResponse(
id=template.id,
name=template.name,
description=template.description,
template_type=template.template_type,
content_zh=template.content_zh,
content_en=template.content_en,
variables=template_in.variables or {},
is_default=template.is_default,
is_system=template.is_system,
is_active=template.is_active,
sort_order=template.sort_order,
created_by=template.created_by,
created_at=template.created_at,
updated_at=template.updated_at,
)
@router.put("/{template_id}", response_model=PromptTemplateResponse)
async def update_prompt_template(
template_id: str,
template_in: PromptTemplateUpdate,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(deps.get_current_user),
) -> Any:
"""更新提示词模板"""
result = await db.execute(
select(PromptTemplate).where(PromptTemplate.id == template_id)
)
template = result.scalar_one_or_none()
if not template:
raise HTTPException(status_code=404, detail="模板不存在")
# 系统模板不允许修改核心内容,只能修改启用状态
if template.is_system:
if template_in.is_active is not None:
template.is_active = template_in.is_active
else:
raise HTTPException(status_code=403, detail="系统模板不允许修改")
else:
# 检查权限
if template.created_by != current_user.id:
raise HTTPException(status_code=403, detail="无权修改此模板")
# 更新字段
update_data = template_in.dict(exclude_unset=True)
for field, value in update_data.items():
if field == "variables" and value is not None:
setattr(template, field, json.dumps(value))
elif field != "is_default": # 不允许用户设置默认
setattr(template, field, value)
await db.commit()
await db.refresh(template)
variables = {}
if template.variables:
try:
variables = json.loads(template.variables)
except:
pass
return PromptTemplateResponse(
id=template.id,
name=template.name,
description=template.description,
template_type=template.template_type,
content_zh=template.content_zh,
content_en=template.content_en,
variables=variables,
is_default=template.is_default,
is_system=template.is_system,
is_active=template.is_active,
sort_order=template.sort_order,
created_by=template.created_by,
created_at=template.created_at,
updated_at=template.updated_at,
)
@router.delete("/{template_id}")
async def delete_prompt_template(
template_id: str,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(deps.get_current_user),
) -> Any:
"""删除提示词模板"""
result = await db.execute(
select(PromptTemplate).where(PromptTemplate.id == template_id)
)
template = result.scalar_one_or_none()
if not template:
raise HTTPException(status_code=404, detail="模板不存在")
if template.is_system:
raise HTTPException(status_code=403, detail="系统模板不允许删除")
if template.created_by != current_user.id:
raise HTTPException(status_code=403, detail="无权删除此模板")
await db.delete(template)
await db.commit()
return {"message": "模板已删除"}
@router.post("/test", response_model=PromptTestResponse)
async def test_prompt_template(
request: PromptTestRequest,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(deps.get_current_user),
) -> Any:
"""测试提示词效果"""
from app.services.llm.service import LLMService
from app.models.user_config import UserConfig
from app.core.encryption import decrypt_sensitive_data
start_time = time.time()
try:
# 获取用户配置
user_config = {}
result_config = await db.execute(
select(UserConfig).where(UserConfig.user_id == current_user.id)
)
config = result_config.scalar_one_or_none()
if config:
# 需要解密的敏感字段
SENSITIVE_LLM_FIELDS = [
'llmApiKey', 'geminiApiKey', 'openaiApiKey', 'claudeApiKey',
'qwenApiKey', 'deepseekApiKey', 'zhipuApiKey', 'moonshotApiKey',
'baiduApiKey', 'minimaxApiKey', 'doubaoApiKey'
]
llm_config = json.loads(config.llm_config) if config.llm_config else {}
for field in SENSITIVE_LLM_FIELDS:
if field in llm_config and llm_config[field]:
llm_config[field] = decrypt_sensitive_data(llm_config[field])
user_config = {'llmConfig': llm_config}
# 创建使用用户配置的LLM服务实例
llm_service = LLMService(user_config=user_config)
# 使用自定义提示词进行分析
result = await llm_service.analyze_code_with_custom_prompt(
code=request.code,
language=request.language,
custom_prompt=request.content,
output_language=request.output_language,
)
execution_time = time.time() - start_time
return PromptTestResponse(
success=True,
result=result,
execution_time=round(execution_time, 2),
)
except Exception as e:
execution_time = time.time() - start_time
import traceback
print(f"❌ 提示词测试失败: {e}")
print(traceback.format_exc())
return PromptTestResponse(
success=False,
error=str(e),
execution_time=round(execution_time, 2),
)
@router.post("/{template_id}/set-default")
async def set_default_template(
template_id: str,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(deps.get_current_user),
) -> Any:
"""设置默认模板(仅管理员)"""
if not current_user.is_superuser:
raise HTTPException(status_code=403, detail="仅管理员可设置默认模板")
result = await db.execute(
select(PromptTemplate).where(PromptTemplate.id == template_id)
)
template = result.scalar_one_or_none()
if not template:
raise HTTPException(status_code=404, detail="模板不存在")
# 取消同类型的其他默认模板
await db.execute(
select(PromptTemplate)
.where(PromptTemplate.template_type == template.template_type)
.where(PromptTemplate.is_default == True)
)
same_type_defaults = (await db.execute(
select(PromptTemplate)
.where(PromptTemplate.template_type == template.template_type)
.where(PromptTemplate.is_default == True)
)).scalars().all()
for t in same_type_defaults:
t.is_default = False
# 设置新的默认模板
template.is_default = True
await db.commit()
return {"message": "已设置为默认模板"}