Files
bianchengshequ/backend/routers/ai_models.py

286 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""AI模型管理路由"""
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from typing import List
from database import get_db
from models.ai_model import AIModelConfig
from models.user import User
from schemas.ai_model import AIModelCreate, AIModelUpdate, AIModelResponse, ProviderInfo
from routers.auth import get_admin_user, get_current_user
# Router for the model-configuration endpoints, mounted under /api/admin/models.
router = APIRouter(prefix="/api/admin/models", tags=["AI模型管理"])
# Public router (usable by logged-in users; lets the front-end AI tools fetch the selectable models).
public_router = APIRouter(prefix="/api/models", tags=["AI模型公开"])
# Built-in catalogue of providers and the models offered for each of them.
# Every provider entry carries:
#   provider          - short identifier of the provider
#   name              - display name
#   default_base_url  - base URL suggested for that provider
#   models            - list of {model_id, name, task_types, description}
PROVIDER_PRESETS = [
    # --- DeepSeek ---
    {
        "provider": "deepseek",
        "name": "DeepSeek",
        "default_base_url": "https://api.deepseek.com",
        "models": [
            {
                "model_id": "deepseek-chat",
                "name": "DeepSeek-V3.2",
                "task_types": ["lightweight", "knowledge_base"],
                "description": "DeepSeek-V3.2 非思考模式,性价比极高",
            },
            {
                "model_id": "deepseek-reasoner",
                "name": "DeepSeek-V3.2 思考",
                "task_types": ["reasoning", "knowledge_base"],
                "description": "DeepSeek-V3.2 思考模式,带推理链输出",
            },
        ],
    },
    # --- OpenAI ---
    {
        "provider": "openai",
        "name": "OpenAI",
        "default_base_url": "https://api.openai.com/v1",
        "models": [
            {
                "model_id": "gpt-4o",
                "name": "GPT-4o",
                "task_types": ["multimodal", "reasoning"],
                "description": "多模态旗舰模型",
            },
            {
                "model_id": "gpt-4o-mini",
                "name": "GPT-4o Mini",
                "task_types": ["lightweight"],
                "description": "轻量高效模型",
            },
            {
                "model_id": "o3-mini",
                "name": "o3-mini",
                "task_types": ["reasoning"],
                "description": "推理增强模型",
            },
            {
                "model_id": "text-embedding-3-large",
                "name": "Embedding Large",
                "task_types": ["embedding"],
                "description": "高维度文本嵌入模型",
            },
        ],
    },
    # --- Anthropic ---
    {
        "provider": "anthropic",
        "name": "Anthropic",
        "default_base_url": "https://api.anthropic.com",
        "models": [
            {
                "model_id": "claude-sonnet-4-20250514",
                "name": "Claude Sonnet 4",
                "task_types": ["reasoning"],
                "description": "Claude最新推理模型",
            },
            {
                "model_id": "claude-3-5-haiku-20241022",
                "name": "Claude 3.5 Haiku",
                "task_types": ["lightweight"],
                "description": "快速轻量模型",
            },
        ],
    },
    # --- Google Gemini (OpenAI-compatible endpoint path) ---
    {
        "provider": "google",
        "name": "Google Gemini",
        "default_base_url": "https://generativelanguage.googleapis.com/v1beta/openai",
        "models": [
            {
                "model_id": "gemini-2.5-pro-preview-06-05",
                "name": "Gemini 2.5 Pro",
                "task_types": ["multimodal", "reasoning"],
                "description": "多模态能力最强",
            },
            {
                "model_id": "gemini-2.0-flash",
                "name": "Gemini 2.0 Flash",
                "task_types": ["lightweight", "multimodal"],
                "description": "快速多模态模型",
            },
        ],
    },
    # --- Volcano Ark (Doubao) ---
    {
        "provider": "ark",
        "name": "火山方舟(豆包)",
        "default_base_url": "https://ark.cn-beijing.volces.com/api/v3",
        "models": [
            {
                # NOTE(review): this id looks like an account-specific endpoint id
                # rather than a public model name - confirm it is meant to ship as a preset.
                "model_id": "ep-20260411180700-z6nll",
                "name": "Doubao-Seed-2.0-pro",
                "task_types": ["reasoning", "lightweight", "knowledge_base"],
                "description": "豆包旗舰模型,支持联网搜索",
            },
            {
                "model_id": "doubao-seedream-5-0-260128",
                "name": "Seedream 5.0 (图像生成)",
                "task_types": ["image"],
                "description": "豆包图像生成模型,支持文生图",
            },
        ],
    },
]
# Human-readable labels for every supported task type, keyed by the task-type
# identifier used in the presets and in the stored model configurations.
TASK_TYPE_LABELS = {
    "multimodal": "多模态(图片/草图理解)",
    "reasoning": "推理分析(需求解读/架构分析)",
    "lightweight": "轻量任务(分类/标签)",
    "knowledge_base": "知识库分析(文档理解/问答)",
    "embedding": "向量嵌入",
    # Fixed: the opening parenthesis was missing, leaving an unbalanced label.
    "image": "图像生成(AI配图/文生图)",
}
def _mask_api_key(key: str) -> str:
"""API Key脱敏"""
if not key or len(key) < 8:
return "****" if key else ""
return key[:4] + "*" * (len(key) - 8) + key[-4:]
def _to_response(model: AIModelConfig) -> dict:
    """Build the response dictionary for one stored model configuration.

    The raw ``api_key`` is never copied into the result; only its masked
    form is exposed under ``api_key_masked``.
    """
    # Attributes copied verbatim, split around the masked key so the key
    # order of the resulting dictionary stays the same.
    leading = ("id", "provider", "provider_name", "model_id", "model_name")
    trailing = (
        "base_url",
        "task_type",
        "is_enabled",
        "is_default",
        "web_search_enabled",
        "description",
        "created_at",
        "updated_at",
    )

    payload = {name: getattr(model, name) for name in leading}
    payload["api_key_masked"] = _mask_api_key(model.api_key)
    payload.update((name, getattr(model, name)) for name in trailing)
    return payload
@router.get("/presets", response_model=List[ProviderInfo])
async def get_provider_presets():
    """Return the built-in list of providers and their preset models."""
    presets = PROVIDER_PRESETS
    return presets
@router.get("/task-types")
async def get_task_types():
    """Return the mapping of task-type identifiers to their display labels."""
    labels = TASK_TYPE_LABELS
    return labels
@router.get("", response_model=List[AIModelResponse])
async def list_models(
    provider: str = None,
    task_type: str = None,
    db: Session = Depends(get_db),
    admin: User = Depends(get_admin_user),
):
    """List the configured models, optionally filtered.

    Args:
        provider: When given, only configurations of this provider are returned.
        task_type: When given, only configurations of this task type are returned.
        db: Database session.
        admin: Result of the ``get_admin_user`` dependency. It is required here
            like on the other database-backed endpoints of this router, so the
            stored configuration is not readable without that check.

    Returns:
        The matching configurations in response format (API keys masked),
        ordered by provider and then by creation time.
    """
    query = db.query(AIModelConfig)
    if provider:
        query = query.filter(AIModelConfig.provider == provider)
    if task_type:
        query = query.filter(AIModelConfig.task_type == task_type)
    models = query.order_by(AIModelConfig.provider, AIModelConfig.created_at).all()
    return [_to_response(m) for m in models]
@router.post("", response_model=AIModelResponse)
async def create_model(data: AIModelCreate, db: Session = Depends(get_db), admin: User = Depends(get_admin_user)):
    """Create a new model configuration and return it in response format.

    When the new configuration is flagged as default and has a task type,
    every other default of that task type is unflagged first.
    """
    # Demote the current defaults of the same task type before inserting.
    wants_default = data.is_default and data.task_type
    if wants_default:
        current_defaults = db.query(AIModelConfig).filter(
            AIModelConfig.task_type == data.task_type,
            AIModelConfig.is_default == True,  # noqa: E712 - SQLAlchemy column comparison
        )
        current_defaults.update({"is_default": False})

    # Copy the request fields one-to-one onto the new row.
    copied_fields = (
        "provider",
        "provider_name",
        "model_id",
        "model_name",
        "api_key",
        "base_url",
        "task_type",
        "is_enabled",
        "is_default",
        "web_search_enabled",
        "description",
    )
    record = AIModelConfig(**{name: getattr(data, name) for name in copied_fields})

    db.add(record)
    db.commit()
    db.refresh(record)
    return _to_response(record)
@router.put("/{model_id}", response_model=AIModelResponse)
async def update_model(model_id: int, data: AIModelUpdate, db: Session = Depends(get_db), admin: User = Depends(get_admin_user)):
    """Update an existing model configuration.

    Only the fields explicitly present in the request are changed. An empty
    string for ``api_key`` means "keep the stored key".

    Args:
        model_id: Primary key of the configuration to update.
        data: Partial update payload.
        db: Database session.
        admin: Result of the ``get_admin_user`` dependency.

    Returns:
        The updated configuration in response format.

    Raises:
        HTTPException: 404 when no configuration with ``model_id`` exists.
    """
    model = db.query(AIModelConfig).filter(AIModelConfig.id == model_id).first()
    if not model:
        raise HTTPException(status_code=404, detail="模型配置不存在")

    update_data = data.dict(exclude_unset=True)

    # An empty API key means "do not modify the stored key".
    if update_data.get("api_key") == "":
        del update_data["api_key"]

    # Keep at most one default per task type. The check uses the values the
    # row will have AFTER the update, so that moving an already-default model
    # to another task type also demotes the previous default of that type.
    touches_default = "is_default" in update_data or "task_type" in update_data
    will_be_default = update_data.get("is_default", model.is_default)
    target_task = update_data.get("task_type", model.task_type)
    if touches_default and will_be_default and target_task:
        db.query(AIModelConfig).filter(
            AIModelConfig.task_type == target_task,
            AIModelConfig.is_default == True,  # noqa: E712 - SQLAlchemy column comparison
            AIModelConfig.id != model_id,
        ).update({"is_default": False})

    for key, value in update_data.items():
        setattr(model, key, value)

    db.commit()
    db.refresh(model)
    return _to_response(model)
@router.delete("/{model_id}")
async def delete_model(model_id: int, db: Session = Depends(get_db), admin: User = Depends(get_admin_user)):
    """Delete the model configuration with the given id.

    Raises:
        HTTPException: 404 when no configuration with ``model_id`` exists.
    """
    lookup = db.query(AIModelConfig).filter(AIModelConfig.id == model_id)
    target = lookup.first()
    if target:
        db.delete(target)
        db.commit()
        return {"message": "删除成功"}
    raise HTTPException(status_code=404, detail="模型配置不存在")
@router.post("/init-defaults")
async def init_default_models(db: Session = Depends(get_db), admin: User = Depends(get_admin_user)):
    """Seed the default model configurations, but only into an empty table.

    Returns a message plus a count: the number of existing rows when seeding
    is skipped, otherwise the number of rows that were inserted.
    """
    existing = db.query(AIModelConfig).count()
    if existing > 0:
        return {"message": f"已有 {existing} 条配置,跳过初始化", "count": existing}

    # Keyword arguments for each seeded row; no API keys are set here.
    seed_specs = [
        {
            "provider": "deepseek",
            "provider_name": "DeepSeek",
            "model_id": "deepseek-chat",
            "model_name": "DeepSeek-V3",
            "task_type": "reasoning",
            "is_default": True,
            "is_enabled": True,
            "base_url": "https://api.deepseek.com/v1",
            "description": "DeepSeek最新对话模型性价比极高",
        },
        {
            "provider": "deepseek",
            "provider_name": "DeepSeek",
            "model_id": "deepseek-reasoner",
            "model_name": "DeepSeek-R1",
            "task_type": "",
            "is_enabled": True,
            "base_url": "https://api.deepseek.com/v1",
            "description": "深度推理模型,适合复杂逻辑分析",
        },
        {
            "provider": "openai",
            "provider_name": "OpenAI",
            "model_id": "gpt-4o-mini",
            "model_name": "GPT-4o Mini",
            "task_type": "lightweight",
            "is_default": True,
            "is_enabled": True,
            "description": "轻量高效模型",
        },
        {
            "provider": "google",
            "provider_name": "Google Gemini",
            "model_id": "gemini-2.5-pro-preview-06-05",
            "model_name": "Gemini 2.5 Pro",
            "task_type": "multimodal",
            "is_default": True,
            "is_enabled": True,
            "base_url": "https://generativelanguage.googleapis.com/v1beta/openai",
            "description": "多模态能力最强",
        },
    ]
    seeded = [AIModelConfig(**spec) for spec in seed_specs]

    db.add_all(seeded)
    db.commit()
    inserted = len(seeded)
    return {"message": f"已初始化 {inserted} 条默认配置", "count": inserted}
def _build_test_request(model):
    """Return ``(url, headers, payload)`` for a minimal probe request to ``model``.

    Anthropic is addressed through its ``/v1/messages`` endpoint with
    ``x-api-key`` authentication; every other provider is treated as
    OpenAI-compatible (``/chat/completions`` with a bearer token).
    """
    # Strip trailing slashes so joining never produces "//" in the path.
    base_url = (model.base_url or "").rstrip("/")
    payload = {"model": model.model_id, "max_tokens": 10, "messages": [{"role": "user", "content": "hi"}]}

    if model.provider == "anthropic":
        base_url = base_url or "https://api.anthropic.com"
        # Accept both the bare host and a base URL that already ends in "/v1";
        # otherwise the path would become ".../v1/v1/messages".
        if base_url.endswith("/v1"):
            base_url = base_url[: -len("/v1")]
        url = f"{base_url}/v1/messages"
        headers = {"x-api-key": model.api_key, "anthropic-version": "2023-06-01", "Content-Type": "application/json"}
    else:
        base_url = base_url or f"https://api.{model.provider}.com/v1"
        url = f"{base_url}/chat/completions"
        headers = {"Authorization": f"Bearer {model.api_key}", "Content-Type": "application/json"}

    return url, headers, payload


@router.post("/{model_id}/test")
async def test_model_connection(model_id: int, db: Session = Depends(get_db), admin: User = Depends(get_admin_user)):
    """Check whether the stored configuration can reach its provider.

    Sends one tiny chat request with the stored API key and reports the outcome.

    Args:
        model_id: Primary key of the configuration to test.
        db: Database session.
        admin: Result of the ``get_admin_user`` dependency.

    Returns:
        ``{"success": bool, "message": str}``. Connection problems are reported
        in the message instead of being raised.

    Raises:
        HTTPException: 404 when no configuration with ``model_id`` exists.
    """
    model = db.query(AIModelConfig).filter(AIModelConfig.id == model_id).first()
    if not model:
        raise HTTPException(status_code=404, detail="模型配置不存在")
    if not model.api_key:
        return {"success": False, "message": "未配置 API Key"}

    # NOTE(review): embedding and image models are probed through the chat
    # endpoint as well, which presumably fails for them - confirm whether they
    # need dedicated probes.
    try:
        # Imported here so that a missing httpx is reported as a failed test
        # rather than breaking the import of this module.
        import httpx

        url, headers, payload = _build_test_request(model)
        async with httpx.AsyncClient(timeout=15.0) as client:
            resp = await client.post(url, json=payload, headers=headers)
    except Exception as e:  # boundary: any failure is turned into a result for the UI
        return {"success": False, "message": f"连接失败: {str(e)}"}

    if resp.status_code == 200:
        return {"success": True, "message": "连接成功"}
    # Truncate the body so a large error page is not echoed back in full.
    return {"success": False, "message": f"HTTP {resp.status_code}: {resp.text[:200]}"}
# ===== Public endpoints =====
@public_router.get("/available")
async def get_available_models(
    task_type: str = "",
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """List the models that can currently be used; API keys are never returned.

    Only enabled configurations with a non-empty API key are included,
    optionally restricted to one task type. Defaults come first, then the
    rest by creation time.
    """
    usable = db.query(AIModelConfig).filter(
        AIModelConfig.is_enabled == True,  # noqa: E712 - SQLAlchemy column comparison
        AIModelConfig.api_key != "",
    )
    if task_type:
        usable = usable.filter(AIModelConfig.task_type == task_type)
    ordered = usable.order_by(AIModelConfig.is_default.desc(), AIModelConfig.created_at)

    available = []
    for entry in ordered.all():
        available.append(
            {
                "id": entry.id,
                "provider": entry.provider,
                "provider_name": entry.provider_name,
                "model_id": entry.model_id,
                "model_name": entry.model_name,
                "task_type": entry.task_type,
                "is_default": entry.is_default,
                "web_search_enabled": entry.web_search_enabled,
                # Falls back to 5 when no search count is stored (or it is 0).
                "web_search_count": entry.web_search_count or 5,
                "description": entry.description,
            }
        )
    return available