286 lines
12 KiB
Python
286 lines
12 KiB
Python
"""AI模型管理路由"""
|
||
from fastapi import APIRouter, Depends, HTTPException
|
||
from sqlalchemy.orm import Session
|
||
from typing import List
|
||
from database import get_db
|
||
from models.ai_model import AIModelConfig
|
||
from models.user import User
|
||
from schemas.ai_model import AIModelCreate, AIModelUpdate, AIModelResponse, ProviderInfo
|
||
from routers.auth import get_admin_user, get_current_user
|
||
|
||
# Admin router: every route defined on it lives under /api/admin/models.
router = APIRouter(
    prefix="/api/admin/models",
    tags=["AI模型管理"],
)

# Public router (usable by logged-in users; lets the front-end AI tools fetch
# the list of selectable models).
public_router = APIRouter(
    prefix="/api/models",
    tags=["AI模型公开"],
)
|
||
|
||
# Built-in catalogue of providers and the models offered for each one.
# Every provider entry carries its key, display name, default API base URL and
# a list of models; every model lists the task types it is suggested for.
PROVIDER_PRESETS = [
    # --- DeepSeek ---
    {
        "provider": "deepseek",
        "name": "DeepSeek",
        "default_base_url": "https://api.deepseek.com",
        "models": [
            {
                "model_id": "deepseek-chat",
                "name": "DeepSeek-V3.2",
                "task_types": ["lightweight", "knowledge_base"],
                "description": "DeepSeek-V3.2 非思考模式,性价比极高",
            },
            {
                "model_id": "deepseek-reasoner",
                "name": "DeepSeek-V3.2 思考",
                "task_types": ["reasoning", "knowledge_base"],
                "description": "DeepSeek-V3.2 思考模式,带推理链输出",
            },
        ],
    },
    # --- OpenAI ---
    {
        "provider": "openai",
        "name": "OpenAI",
        "default_base_url": "https://api.openai.com/v1",
        "models": [
            {
                "model_id": "gpt-4o",
                "name": "GPT-4o",
                "task_types": ["multimodal", "reasoning"],
                "description": "多模态旗舰模型",
            },
            {
                "model_id": "gpt-4o-mini",
                "name": "GPT-4o Mini",
                "task_types": ["lightweight"],
                "description": "轻量高效模型",
            },
            {
                "model_id": "o3-mini",
                "name": "o3-mini",
                "task_types": ["reasoning"],
                "description": "推理增强模型",
            },
            {
                "model_id": "text-embedding-3-large",
                "name": "Embedding Large",
                "task_types": ["embedding"],
                "description": "高维度文本嵌入模型",
            },
        ],
    },
    # --- Anthropic ---
    {
        "provider": "anthropic",
        "name": "Anthropic",
        "default_base_url": "https://api.anthropic.com",
        "models": [
            {
                "model_id": "claude-sonnet-4-20250514",
                "name": "Claude Sonnet 4",
                "task_types": ["reasoning"],
                "description": "Claude最新推理模型",
            },
            {
                "model_id": "claude-3-5-haiku-20241022",
                "name": "Claude 3.5 Haiku",
                "task_types": ["lightweight"],
                "description": "快速轻量模型",
            },
        ],
    },
    # --- Google Gemini (OpenAI-compatible endpoint) ---
    {
        "provider": "google",
        "name": "Google Gemini",
        "default_base_url": "https://generativelanguage.googleapis.com/v1beta/openai",
        "models": [
            {
                "model_id": "gemini-2.5-pro-preview-06-05",
                "name": "Gemini 2.5 Pro",
                "task_types": ["multimodal", "reasoning"],
                "description": "多模态能力最强",
            },
            {
                "model_id": "gemini-2.0-flash",
                "name": "Gemini 2.0 Flash",
                "task_types": ["lightweight", "multimodal"],
                "description": "快速多模态模型",
            },
        ],
    },
    # --- Volcano Ark (Doubao) ---
    {
        "provider": "ark",
        "name": "火山方舟(豆包)",
        "default_base_url": "https://ark.cn-beijing.volces.com/api/v3",
        "models": [
            {
                "model_id": "ep-20260411180700-z6nll",
                "name": "Doubao-Seed-2.0-pro",
                "task_types": ["reasoning", "lightweight", "knowledge_base"],
                "description": "豆包旗舰模型,支持联网搜索",
            },
            {
                "model_id": "doubao-seedream-5-0-260128",
                "name": "Seedream 5.0 (图像生成)",
                "task_types": ["image"],
                "description": "豆包图像生成模型,支持文生图",
            },
        ],
    },
]
|
||
|
||
# Task-type key -> human-readable label, served as-is by the task-types route.
TASK_TYPE_LABELS = {
    # image / sketch understanding
    "multimodal": "多模态(图片/草图理解)",
    # requirement interpretation / architecture analysis
    "reasoning": "推理分析(需求解读/架构分析)",
    # classification / tagging
    "lightweight": "轻量任务(分类/标签)",
    # document understanding / Q&A
    "knowledge_base": "知识库分析(文档理解/问答)",
    # vector embeddings
    "embedding": "向量嵌入",
    # image generation (text-to-image)
    "image": "图像生成(AI配图/文生图)",
}
|
||
|
||
|
||
def _mask_api_key(key: str) -> str:
|
||
"""API Key脱敏"""
|
||
if not key or len(key) < 8:
|
||
return "****" if key else ""
|
||
return key[:4] + "*" * (len(key) - 8) + key[-4:]
|
||
|
||
|
||
def _to_response(model: AIModelConfig) -> dict:
    """Convert a model configuration row into a response dictionary.

    All fields are copied from ``model`` unchanged except the API key, which
    is exposed only in masked form under ``api_key_masked``.
    """
    # Fields copied verbatim, split around the position of the masked key so
    # the resulting dict keeps the established key order.
    before_key = ("id", "provider", "provider_name", "model_id", "model_name")
    after_key = (
        "base_url",
        "task_type",
        "is_enabled",
        "is_default",
        "web_search_enabled",
        "description",
        "created_at",
        "updated_at",
    )

    payload = {field: getattr(model, field) for field in before_key}
    payload["api_key_masked"] = _mask_api_key(model.api_key)
    for field in after_key:
        payload[field] = getattr(model, field)
    return payload
|
||
|
||
|
||
@router.get("/presets", response_model=List[ProviderInfo])
async def get_provider_presets():
    """Return the built-in list of providers and their preset models."""
    presets = PROVIDER_PRESETS
    return presets
|
||
|
||
|
||
@router.get("/task-types")
async def get_task_types():
    """Return the mapping of task-type keys to their display labels."""
    labels = TASK_TYPE_LABELS
    return labels
|
||
|
||
|
||
@router.get("", response_model=List[AIModelResponse])
async def list_models(
    provider: str = None,
    task_type: str = None,
    db: Session = Depends(get_db),
    admin: User = Depends(get_admin_user),
):
    """List all configured models, optionally filtered.

    Args:
        provider: When given, only configurations of this provider are returned.
        task_type: When given, only configurations with this task type are returned.
        db: Database session.
        admin: Resolved by ``get_admin_user``. Unused in the body; it is
            declared so this listing (which exposes base URLs and partially
            masked API keys) is guarded the same way as the create / update /
            delete / test routes on this admin router.

    Returns:
        Response dictionaries ordered by provider, then creation time.
    """
    query = db.query(AIModelConfig)
    if provider:
        query = query.filter(AIModelConfig.provider == provider)
    if task_type:
        query = query.filter(AIModelConfig.task_type == task_type)
    models = query.order_by(AIModelConfig.provider, AIModelConfig.created_at).all()
    return [_to_response(m) for m in models]
|
||
|
||
|
||
@router.post("", response_model=AIModelResponse)
async def create_model(
    data: AIModelCreate,
    db: Session = Depends(get_db),
    admin: User = Depends(get_admin_user),
):
    """Add a model configuration.

    When the new configuration is flagged as default and has a task type, any
    existing default for that task type is demoted before the new row is saved.
    """
    new_config = AIModelConfig(
        provider=data.provider,
        provider_name=data.provider_name,
        model_id=data.model_id,
        model_name=data.model_name,
        api_key=data.api_key,
        base_url=data.base_url,
        task_type=data.task_type,
        is_enabled=data.is_enabled,
        is_default=data.is_default,
        web_search_enabled=data.web_search_enabled,
        description=data.description,
    )

    # Demote the current default(s) of the same task type, if this one takes over.
    takes_over_default = data.is_default and data.task_type
    if takes_over_default:
        current_defaults = db.query(AIModelConfig).filter(
            AIModelConfig.task_type == data.task_type,
            AIModelConfig.is_default == True,  # noqa: E712 - SQL expression, not a Python comparison
        )
        current_defaults.update({"is_default": False})

    db.add(new_config)
    db.commit()
    db.refresh(new_config)
    return _to_response(new_config)
|
||
|
||
|
||
@router.put("/{model_id}", response_model=AIModelResponse)
async def update_model(
    model_id: int,
    data: AIModelUpdate,
    db: Session = Depends(get_db),
    admin: User = Depends(get_admin_user),
):
    """Update a model configuration.

    Only fields explicitly present in the request are applied. An empty-string
    ``api_key`` means "keep the stored key". If the record ends up being the
    default for a non-empty task type, other defaults of that task type are
    demoted.

    Args:
        model_id: Primary key of the configuration to update.
        data: Partial update payload.
        db: Database session.
        admin: Resolved by ``get_admin_user``; unused in the body.

    Returns:
        The updated configuration as a response dictionary.

    Raises:
        HTTPException: 404 when no configuration has the given id.
    """
    model = db.query(AIModelConfig).filter(AIModelConfig.id == model_id).first()
    if model is None:
        raise HTTPException(status_code=404, detail="模型配置不存在")

    update_data = data.dict(exclude_unset=True)

    # An empty-string API key means "do not modify the stored key".
    if update_data.get("api_key") == "":
        del update_data["api_key"]

    # Resolve the values the record will hold AFTER this update. The guard and
    # the filter below must use the same task type: previously the guard fell
    # back to the stored task type while the filter used the (possibly empty)
    # submitted one. Resolving is_default the same way also covers moving an
    # already-default model to another task type without resending the flag.
    will_be_default = update_data.get("is_default", model.is_default)
    target_task = update_data.get("task_type", model.task_type)

    # Same rule as on creation: only a default with a non-empty task type
    # displaces the other defaults of that task type.
    if will_be_default and target_task:
        db.query(AIModelConfig).filter(
            AIModelConfig.task_type == target_task,
            AIModelConfig.is_default == True,  # noqa: E712 - SQL expression
            AIModelConfig.id != model_id,
        ).update({"is_default": False})

    for key, value in update_data.items():
        setattr(model, key, value)
    db.commit()
    db.refresh(model)
    return _to_response(model)
|
||
|
||
|
||
@router.delete("/{model_id}")
async def delete_model(
    model_id: int,
    db: Session = Depends(get_db),
    admin: User = Depends(get_admin_user),
):
    """Delete a model configuration; responds with 404 if the id is unknown."""
    lookup = db.query(AIModelConfig).filter(AIModelConfig.id == model_id)
    target = lookup.first()
    if target:
        db.delete(target)
        db.commit()
        return {"message": "删除成功"}
    raise HTTPException(status_code=404, detail="模型配置不存在")
|
||
|
||
|
||
@router.post("/init-defaults")
async def init_default_models(
    db: Session = Depends(get_db),
    admin: User = Depends(get_admin_user),
):
    """Seed the default model configurations (only when the table is empty)."""
    count = db.query(AIModelConfig).count()
    if count > 0:
        # Existing rows are never touched; report how many there are and stop.
        return {"message": f"已有 {count} 条配置,跳过初始化", "count": count}

    # Keyword arguments for each seed row; omitted fields are left to the
    # model's own defaults.
    seed_specs = [
        {
            "provider": "deepseek",
            "provider_name": "DeepSeek",
            "model_id": "deepseek-chat",
            "model_name": "DeepSeek-V3",
            "task_type": "reasoning",
            "is_default": True,
            "is_enabled": True,
            "base_url": "https://api.deepseek.com/v1",
            "description": "DeepSeek最新对话模型,性价比极高",
        },
        {
            "provider": "deepseek",
            "provider_name": "DeepSeek",
            "model_id": "deepseek-reasoner",
            "model_name": "DeepSeek-R1",
            "task_type": "",
            "is_enabled": True,
            "base_url": "https://api.deepseek.com/v1",
            "description": "深度推理模型,适合复杂逻辑分析",
        },
        {
            "provider": "openai",
            "provider_name": "OpenAI",
            "model_id": "gpt-4o-mini",
            "model_name": "GPT-4o Mini",
            "task_type": "lightweight",
            "is_default": True,
            "is_enabled": True,
            "description": "轻量高效模型",
        },
        {
            "provider": "google",
            "provider_name": "Google Gemini",
            "model_id": "gemini-2.5-pro-preview-06-05",
            "model_name": "Gemini 2.5 Pro",
            "task_type": "multimodal",
            "is_default": True,
            "is_enabled": True,
            "base_url": "https://generativelanguage.googleapis.com/v1beta/openai",
            "description": "多模态能力最强",
        },
    ]
    defaults = [AIModelConfig(**spec) for spec in seed_specs]

    db.add_all(defaults)
    db.commit()
    return {"message": f"已初始化 {len(defaults)} 条默认配置", "count": len(defaults)}
|
||
|
||
|
||
@router.post("/{model_id}/test")
async def test_model_connection(
    model_id: int,
    db: Session = Depends(get_db),
    admin: User = Depends(get_admin_user),
):
    """Check whether a configured model can be reached.

    Sends a minimal one-message chat request ("hi", ``max_tokens`` 10) to the
    provider and reports whether it answered with HTTP 200.

    Args:
        model_id: Primary key of the configuration to test.
        db: Database session.
        admin: Resolved by ``get_admin_user``; unused in the body.

    Returns:
        ``{"success": bool, "message": str}``. Failures (missing key, non-200
        status, transport errors) are reported in the body, not raised.

    Raises:
        HTTPException: 404 when no configuration has the given id.
    """
    model = db.query(AIModelConfig).filter(AIModelConfig.id == model_id).first()
    if not model:
        raise HTTPException(status_code=404, detail="模型配置不存在")
    if not model.api_key:
        return {"success": False, "message": "未配置 API Key"}

    try:
        import httpx

        base_url = model.base_url
        if not base_url:
            # Fall back to the provider's preset base URL. The generic
            # "https://api.<provider>.com/v1" guess is wrong for providers
            # such as google/ark and produced ".../v1/v1/messages" for
            # anthropic; it is kept only for providers without a preset.
            base_url = next(
                (
                    preset["default_base_url"]
                    for preset in PROVIDER_PRESETS
                    if preset["provider"] == model.provider
                ),
                f"https://api.{model.provider}.com/v1",
            )
        # Avoid a double slash when the stored URL ends with "/".
        base_url = base_url.rstrip("/")

        payload = {
            "model": model.model_id,
            "max_tokens": 10,
            "messages": [{"role": "user", "content": "hi"}],
        }

        if model.provider == "anthropic":
            headers = {
                "x-api-key": model.api_key,
                "anthropic-version": "2023-06-01",
                "Content-Type": "application/json",
            }
            # "/v1" is appended below, so accept base URLs saved with or
            # without that suffix.
            if base_url.endswith("/v1"):
                base_url = base_url[: -len("/v1")]
            url = f"{base_url}/v1/messages"
        else:
            headers = {
                "Authorization": f"Bearer {model.api_key}",
                "Content-Type": "application/json",
            }
            url = f"{base_url}/chat/completions"

        async with httpx.AsyncClient(timeout=15.0) as client:
            resp = await client.post(url, json=payload, headers=headers)

        if resp.status_code == 200:
            return {"success": True, "message": "连接成功"}
        return {"success": False, "message": f"HTTP {resp.status_code}: {resp.text[:200]}"}
    except Exception as e:
        # Deliberately broad: this is a diagnostic endpoint, so any failure is
        # reported to the admin as a result rather than a server error.
        return {"success": False, "message": f"连接失败: {str(e)}"}
|
||
|
||
|
||
# ===== 公开接口 =====
|
||
|
||
@public_router.get("/available")
async def get_available_models(
    task_type: str = "",
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """List usable models (callable by logged-in users; never returns API keys).

    Only enabled configurations with a non-empty API key are included,
    optionally narrowed to one task type. Defaults come first, then older
    entries before newer ones.
    """

    def _public_view(cfg) -> dict:
        # Subset of fields exposed publicly; the API key and base URL are
        # left out on purpose.
        return {
            "id": cfg.id,
            "provider": cfg.provider,
            "provider_name": cfg.provider_name,
            "model_id": cfg.model_id,
            "model_name": cfg.model_name,
            "task_type": cfg.task_type,
            "is_default": cfg.is_default,
            "web_search_enabled": cfg.web_search_enabled,
            # Falls back to 5 when no count is stored.
            "web_search_count": cfg.web_search_count or 5,
            "description": cfg.description,
        }

    criteria = [
        AIModelConfig.is_enabled == True,  # noqa: E712 - SQL expression
        AIModelConfig.api_key != "",
    ]
    if task_type:
        criteria.append(AIModelConfig.task_type == task_type)

    rows = (
        db.query(AIModelConfig)
        .filter(*criteria)
        .order_by(AIModelConfig.is_default.desc(), AIModelConfig.created_at)
        .all()
    )
    return [_public_view(cfg) for cfg in rows]
|