"""AI模型管理路由""" from fastapi import APIRouter, Depends, HTTPException from sqlalchemy.orm import Session from typing import List from database import get_db from models.ai_model import AIModelConfig from models.user import User from schemas.ai_model import AIModelCreate, AIModelUpdate, AIModelResponse, ProviderInfo from routers.auth import get_admin_user, get_current_user router = APIRouter(prefix="/api/admin/models", tags=["AI模型管理"]) # 公开路由(登录用户可用,用于前台 AI 工具获取可选模型) public_router = APIRouter(prefix="/api/models", tags=["AI模型公开"]) # 预置的服务商和模型信息 PROVIDER_PRESETS = [ { "provider": "deepseek", "name": "DeepSeek", "default_base_url": "https://api.deepseek.com", "models": [ {"model_id": "deepseek-chat", "name": "DeepSeek-V3.2", "task_types": ["lightweight", "knowledge_base"], "description": "DeepSeek-V3.2 非思考模式,性价比极高"}, {"model_id": "deepseek-reasoner", "name": "DeepSeek-V3.2 思考", "task_types": ["reasoning", "knowledge_base"], "description": "DeepSeek-V3.2 思考模式,带推理链输出"}, ] }, { "provider": "openai", "name": "OpenAI", "default_base_url": "https://api.openai.com/v1", "models": [ {"model_id": "gpt-4o", "name": "GPT-4o", "task_types": ["multimodal", "reasoning"], "description": "多模态旗舰模型"}, {"model_id": "gpt-4o-mini", "name": "GPT-4o Mini", "task_types": ["lightweight"], "description": "轻量高效模型"}, {"model_id": "o3-mini", "name": "o3-mini", "task_types": ["reasoning"], "description": "推理增强模型"}, {"model_id": "text-embedding-3-large", "name": "Embedding Large", "task_types": ["embedding"], "description": "高维度文本嵌入模型"}, ] }, { "provider": "anthropic", "name": "Anthropic", "default_base_url": "https://api.anthropic.com", "models": [ {"model_id": "claude-sonnet-4-20250514", "name": "Claude Sonnet 4", "task_types": ["reasoning"], "description": "Claude最新推理模型"}, {"model_id": "claude-3-5-haiku-20241022", "name": "Claude 3.5 Haiku", "task_types": ["lightweight"], "description": "快速轻量模型"}, ] }, { "provider": "google", "name": "Google Gemini", "default_base_url": "https://generativelanguage.googleapis.com/v1beta/openai", "models": [ {"model_id": "gemini-2.5-pro-preview-06-05", "name": "Gemini 2.5 Pro", "task_types": ["multimodal", "reasoning"], "description": "多模态能力最强"}, {"model_id": "gemini-2.0-flash", "name": "Gemini 2.0 Flash", "task_types": ["lightweight", "multimodal"], "description": "快速多模态模型"}, ] }, { "provider": "ark", "name": "火山方舟(豆包)", "default_base_url": "https://ark.cn-beijing.volces.com/api/v3", "models": [ {"model_id": "ep-20260411180700-z6nll", "name": "Doubao-Seed-2.0-pro", "task_types": ["reasoning", "lightweight", "knowledge_base"], "description": "豆包旗舰模型,支持联网搜索"}, {"model_id": "doubao-seedream-5-0-260128", "name": "Seedream 5.0 (图像生成)", "task_types": ["image"], "description": "豆包图像生成模型,支持文生图"}, ] }, ] TASK_TYPE_LABELS = { "multimodal": "多模态(图片/草图理解)", "reasoning": "推理分析(需求解读/架构分析)", "lightweight": "轻量任务(分类/标签)", "knowledge_base": "知识库分析(文档理解/问答)", "embedding": "向量嵌入", "image": "图像生成(AI配图/文生图)", } def _mask_api_key(key: str) -> str: """API Key脱敏""" if not key or len(key) < 8: return "****" if key else "" return key[:4] + "*" * (len(key) - 8) + key[-4:] def _to_response(model: AIModelConfig) -> dict: """转换为响应格式""" return { "id": model.id, "provider": model.provider, "provider_name": model.provider_name, "model_id": model.model_id, "model_name": model.model_name, "api_key_masked": _mask_api_key(model.api_key), "base_url": model.base_url, "task_type": model.task_type, "is_enabled": model.is_enabled, "is_default": model.is_default, "web_search_enabled": model.web_search_enabled, "description": model.description, "created_at": model.created_at, "updated_at": model.updated_at, } @router.get("/presets", response_model=List[ProviderInfo]) async def get_provider_presets(): """获取预置的服务商和模型列表""" return PROVIDER_PRESETS @router.get("/task-types") async def get_task_types(): """获取任务类型列表""" return TASK_TYPE_LABELS @router.get("", response_model=List[AIModelResponse]) async def list_models(provider: str = None, task_type: str = None, db: Session = Depends(get_db)): """获取所有已配置的模型""" query = db.query(AIModelConfig) if provider: query = query.filter(AIModelConfig.provider == provider) if task_type: query = query.filter(AIModelConfig.task_type == task_type) models = query.order_by(AIModelConfig.provider, AIModelConfig.created_at).all() return [_to_response(m) for m in models] @router.post("", response_model=AIModelResponse) async def create_model(data: AIModelCreate, db: Session = Depends(get_db), admin: User = Depends(get_admin_user)): """添加模型配置""" model = AIModelConfig( provider=data.provider, provider_name=data.provider_name, model_id=data.model_id, model_name=data.model_name, api_key=data.api_key, base_url=data.base_url, task_type=data.task_type, is_enabled=data.is_enabled, is_default=data.is_default, web_search_enabled=data.web_search_enabled, description=data.description, ) # 如果设为默认,取消同任务类型的其他默认 if data.is_default and data.task_type: db.query(AIModelConfig).filter( AIModelConfig.task_type == data.task_type, AIModelConfig.is_default == True ).update({"is_default": False}) db.add(model) db.commit() db.refresh(model) return _to_response(model) @router.put("/{model_id}", response_model=AIModelResponse) async def update_model(model_id: int, data: AIModelUpdate, db: Session = Depends(get_db), admin: User = Depends(get_admin_user)): """更新模型配置""" model = db.query(AIModelConfig).filter(AIModelConfig.id == model_id).first() if not model: raise HTTPException(status_code=404, detail="模型配置不存在") update_data = data.dict(exclude_unset=True) # 如果API Key为空字符串,表示不修改 if "api_key" in update_data and update_data["api_key"] == "": del update_data["api_key"] # 如果设为默认,取消同任务类型的其他默认 if update_data.get("is_default") and (update_data.get("task_type") or model.task_type): task = update_data.get("task_type", model.task_type) db.query(AIModelConfig).filter( AIModelConfig.task_type == task, AIModelConfig.is_default == True, AIModelConfig.id != model_id ).update({"is_default": False}) for key, value in update_data.items(): setattr(model, key, value) db.commit() db.refresh(model) return _to_response(model) @router.delete("/{model_id}") async def delete_model(model_id: int, db: Session = Depends(get_db), admin: User = Depends(get_admin_user)): """删除模型配置""" model = db.query(AIModelConfig).filter(AIModelConfig.id == model_id).first() if not model: raise HTTPException(status_code=404, detail="模型配置不存在") db.delete(model) db.commit() return {"message": "删除成功"} @router.post("/init-defaults") async def init_default_models(db: Session = Depends(get_db), admin: User = Depends(get_admin_user)): """初始化默认模型配置(仅当数据库为空时)""" count = db.query(AIModelConfig).count() if count > 0: return {"message": f"已有 {count} 条配置,跳过初始化", "count": count} defaults = [ AIModelConfig(provider="deepseek", provider_name="DeepSeek", model_id="deepseek-chat", model_name="DeepSeek-V3", task_type="reasoning", is_default=True, is_enabled=True, base_url="https://api.deepseek.com/v1", description="DeepSeek最新对话模型,性价比极高"), AIModelConfig(provider="deepseek", provider_name="DeepSeek", model_id="deepseek-reasoner", model_name="DeepSeek-R1", task_type="", is_enabled=True, base_url="https://api.deepseek.com/v1", description="深度推理模型,适合复杂逻辑分析"), AIModelConfig(provider="openai", provider_name="OpenAI", model_id="gpt-4o-mini", model_name="GPT-4o Mini", task_type="lightweight", is_default=True, is_enabled=True, description="轻量高效模型"), AIModelConfig(provider="google", provider_name="Google Gemini", model_id="gemini-2.5-pro-preview-06-05", model_name="Gemini 2.5 Pro", task_type="multimodal", is_default=True, is_enabled=True, base_url="https://generativelanguage.googleapis.com/v1beta/openai", description="多模态能力最强"), ] db.add_all(defaults) db.commit() return {"message": f"已初始化 {len(defaults)} 条默认配置", "count": len(defaults)} @router.post("/{model_id}/test") async def test_model_connection(model_id: int, db: Session = Depends(get_db), admin: User = Depends(get_admin_user)): """测试模型连接是否正常""" model = db.query(AIModelConfig).filter(AIModelConfig.id == model_id).first() if not model: raise HTTPException(status_code=404, detail="模型配置不存在") if not model.api_key: return {"success": False, "message": "未配置 API Key"} try: import httpx headers = {"Authorization": f"Bearer {model.api_key}", "Content-Type": "application/json"} base_url = model.base_url or f"https://api.{model.provider}.com/v1" if model.provider == "anthropic": headers = {"x-api-key": model.api_key, "anthropic-version": "2023-06-01", "Content-Type": "application/json"} url = f"{base_url}/v1/messages" payload = {"model": model.model_id, "max_tokens": 10, "messages": [{"role": "user", "content": "hi"}]} else: url = f"{base_url}/chat/completions" payload = {"model": model.model_id, "max_tokens": 10, "messages": [{"role": "user", "content": "hi"}]} async with httpx.AsyncClient(timeout=15.0) as client: resp = await client.post(url, json=payload, headers=headers) if resp.status_code == 200: return {"success": True, "message": "连接成功"} else: return {"success": False, "message": f"HTTP {resp.status_code}: {resp.text[:200]}"} except Exception as e: return {"success": False, "message": f"连接失败: {str(e)}"} # ===== 公开接口 ===== @public_router.get("/available") async def get_available_models( task_type: str = "", current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ): """获取可用模型列表(登录用户可调用,不返回 API Key)""" query = db.query(AIModelConfig).filter( AIModelConfig.is_enabled == True, AIModelConfig.api_key != "", ) if task_type: query = query.filter(AIModelConfig.task_type == task_type) models = query.order_by(AIModelConfig.is_default.desc(), AIModelConfig.created_at).all() return [ { "id": m.id, "provider": m.provider, "provider_name": m.provider_name, "model_id": m.model_id, "model_name": m.model_name, "task_type": m.task_type, "is_default": m.is_default, "web_search_enabled": m.web_search_enabled, "web_search_count": m.web_search_count or 5, "description": m.description, } for m in models ]