"""统一大模型调用服务 - 支持多模型路由和流式输出""" import json import httpx import asyncio from typing import AsyncGenerator, List, Optional, Union from openai import AsyncOpenAI from anthropic import AsyncAnthropic from config import ( MODEL_CONFIG, OPENAI_API_KEY, ANTHROPIC_API_KEY, GOOGLE_API_KEY, DEEPSEEK_API_KEY, ARK_API_KEY, ARK_BASE_URL, ) def _get_db_model_config(task_type: str, model_config_id: int = None): """从数据库获取指定任务类型的默认模型配置,或指定 ID 的模型""" try: from database import SessionLocal from models.ai_model import AIModelConfig db = SessionLocal() try: # 如果指定了模型 ID,直接用该模型 if model_config_id: model = db.query(AIModelConfig).filter( AIModelConfig.id == model_config_id, AIModelConfig.is_enabled == True, ).first() if model and model.api_key: return { "provider": model.provider, "model": model.model_id, "api_key": model.api_key, "base_url": model.base_url, "web_search_enabled": model.web_search_enabled, "web_search_count": model.web_search_count or 5, } # 否则找默认模型 model = db.query(AIModelConfig).filter( AIModelConfig.task_type == task_type, AIModelConfig.is_default == True, AIModelConfig.is_enabled == True, ).first() if model and model.api_key: return { "provider": model.provider, "model": model.model_id, "api_key": model.api_key, "base_url": model.base_url, "web_search_enabled": model.web_search_enabled, "web_search_count": model.web_search_count or 5, } # 没有默认的,找任意一个启用且有Key的 model = db.query(AIModelConfig).filter( AIModelConfig.task_type == task_type, AIModelConfig.is_enabled == True, AIModelConfig.api_key != "", ).first() if model: return { "provider": model.provider, "model": model.model_id, "api_key": model.api_key, "base_url": model.base_url, "web_search_enabled": model.web_search_enabled, "web_search_count": model.web_search_count or 5, } finally: db.close() except Exception: pass return None class AIService: """统一AI服务,根据任务类型路由到不同大模型""" def __init__(self): # OpenAI客户端(也用于DeepSeek等兼容API) if OPENAI_API_KEY: self.openai_client = AsyncOpenAI(api_key=OPENAI_API_KEY) else: self.openai_client = None # Anthropic客户端 if ANTHROPIC_API_KEY: self.anthropic_client = AsyncAnthropic(api_key=ANTHROPIC_API_KEY) else: self.anthropic_client = None # DeepSeek客户端(兼容OpenAI接口) if DEEPSEEK_API_KEY: self.deepseek_client = AsyncOpenAI( api_key=DEEPSEEK_API_KEY, base_url="https://api.deepseek.com/v1", ) else: self.deepseek_client = None def _get_client_for_provider(self, provider: str, api_key: str, base_url: str = ""): """根据provider动态创建客户端""" if provider == "anthropic": return AsyncAnthropic(api_key=api_key) # openai/deepseek/google 都用OpenAI兼容接口 kwargs = {"api_key": api_key} if base_url: kwargs["base_url"] = base_url return AsyncOpenAI(**kwargs) async def chat( self, task_type: str, messages: List[dict], system_prompt: str = "", stream: bool = False, model_config_id: int = None, ) -> Union[str, AsyncGenerator[str, None]]: """ 统一对话接口 参数: task_type: 任务类型 (multimodal/reasoning/lightweight) messages: 消息列表 [{"role": "user", "content": "..."}] system_prompt: 系统提示词 stream: 是否流式输出 model_config_id: 指定模型配置ID(可选,不传则用默认) """ # 优先从数据库读取模型配置 db_config = _get_db_model_config(task_type, model_config_id) if db_config: provider = db_config["provider"] model = db_config["model"] api_key = db_config["api_key"] base_url = db_config["base_url"] web_search = db_config.get("web_search_enabled", False) web_search_count = db_config.get("web_search_count", 5) # 火山方舟/豆包 + 联网搜索(开启后自动使用,失败则降级到普通调用) if provider == "ark" and web_search: try: return await self._chat_ark_web_search(api_key, base_url, model, messages, system_prompt, stream, web_search_count) except Exception: # 联网搜索调用失败,降级到普通调用 pass # 火山方舟/豆包 不带联网搜索(OpenAI 兼容接口) if provider == "ark": kwargs = {"api_key": api_key} if base_url: kwargs["base_url"] = base_url else: kwargs["base_url"] = ARK_BASE_URL client = AsyncOpenAI(**kwargs) return await self._chat_openai(client, model, messages, system_prompt, stream) if provider == "anthropic": client = AsyncAnthropic(api_key=api_key) return await self._chat_anthropic_with_client(client, model, messages, system_prompt, stream) else: kwargs = {"api_key": api_key} if base_url: kwargs["base_url"] = base_url client = AsyncOpenAI(**kwargs) # deepseek-reasoner 需要特殊处理 if model == "deepseek-reasoner": return await self._chat_deepseek_reasoner(client, model, messages, system_prompt, stream) return await self._chat_openai(client, model, messages, system_prompt, stream) # 回退到 .env 配置 config = MODEL_CONFIG.get(task_type, MODEL_CONFIG["reasoning"]) provider = config["provider"] model = config["model"] if provider == "anthropic" and self.anthropic_client: return await self._chat_anthropic(model, messages, system_prompt, stream) elif provider == "openai" and self.openai_client: return await self._chat_openai(self.openai_client, model, messages, system_prompt, stream) elif provider == "deepseek" and self.deepseek_client: return await self._chat_openai(self.deepseek_client, model, messages, system_prompt, stream) elif provider == "google": if GOOGLE_API_KEY: google_client = AsyncOpenAI( api_key=GOOGLE_API_KEY, base_url="https://generativelanguage.googleapis.com/v1beta/openai/", ) return await self._chat_openai(google_client, model, messages, system_prompt, stream) # 降级 if self.deepseek_client: return await self._chat_openai(self.deepseek_client, "deepseek-chat", messages, system_prompt, stream) if self.openai_client: return await self._chat_openai(self.openai_client, "gpt-4o-mini", messages, system_prompt, stream) if self.anthropic_client: return await self._chat_anthropic("claude-sonnet-4-20250514", messages, system_prompt, stream) return "未配置任何AI模型,请到「模型管理」页面配置模型和API Key。" async def _chat_ark_web_search( self, api_key: str, base_url: str, model: str, messages: List[dict], system_prompt: str, stream: bool, search_count: int = 5, ) -> Union[str, AsyncGenerator[str, None]]: """火山方舟 + 联网搜索(使用 httpx 直接调用,因 web_search 是非标准 tools 类型)""" url = f"{base_url or ARK_BASE_URL}/chat/completions" full_messages = [] if system_prompt: full_messages.append({"role": "system", "content": system_prompt}) full_messages.extend(messages) # 限制搜索条数范围 1-50 search_count = max(1, min(50, search_count or 5)) payload = { "model": model, "messages": full_messages, "stream": stream, "tools": [{"type": "web_search", "web_search": {"enable": True, "search_result_count": search_count}}], } headers = { "Authorization": f"Bearer {api_key}", "Content-Type": "application/json", } if not stream: async with httpx.AsyncClient(timeout=120.0) as client: resp = await client.post(url, json=payload, headers=headers) if resp.status_code != 200: # 联网搜索调用失败,抛出异常以便降级到普通调用 raise Exception(f"API调用失败 ({resp.status_code}): {resp.text[:200]}") data = resp.json() return data.get("choices", [{}])[0].get("message", {}).get("content", "") else: # 流式调用时,先发一个预检测请求确认模型支持 web_search async with httpx.AsyncClient(timeout=10.0) as client: test_payload = { "model": model, "messages": [{"role": "user", "content": "test"}], "stream": False, "max_tokens": 1, "tools": [{"type": "web_search", "web_search": {"enable": True, "search_result_count": 1}}], } try: resp = await client.post(url, json=test_payload, headers=headers) if resp.status_code == 400: # 模型不支持 web_search,抛出异常以便降级 raise Exception("模型不支持联网搜索") except httpx.TimeoutException: pass # 超时不影响,继续尝试流式调用 return self._stream_ark_web_search(url, payload, headers) async def _stream_ark_web_search( self, url: str, payload: dict, headers: dict, ) -> AsyncGenerator[str, None]: """火山方舟联网搜索流式输出""" async with httpx.AsyncClient(timeout=120.0) as client: async with client.stream("POST", url, json=payload, headers=headers) as resp: if resp.status_code != 200: error_body = await resp.aread() yield f"API调用失败 ({resp.status_code}): {error_body.decode()[:200]}" return buffer = "" async for chunk in resp.aiter_text(): buffer += chunk while "\n" in buffer: line, buffer = buffer.split("\n", 1) line = line.strip() if not line or not line.startswith("data: "): continue data_str = line[6:] if data_str == "[DONE]": continue try: data = json.loads(data_str) choices = data.get("choices", []) if choices: delta = choices[0].get("delta", {}) content = delta.get("content", "") if content: yield content except json.JSONDecodeError: pass async def _chat_openai( self, client: AsyncOpenAI, model: str, messages: List[dict], system_prompt: str, stream: bool, ) -> Union[str, AsyncGenerator[str, None]]: """OpenAI兼容接口调用""" full_messages = [] if system_prompt: full_messages.append({"role": "system", "content": system_prompt}) full_messages.extend(messages) if stream: return self._stream_openai(client, model, full_messages) else: response = await client.chat.completions.create( model=model, messages=full_messages, temperature=0.7, max_tokens=4096, ) return response.choices[0].message.content async def _stream_openai( self, client: AsyncOpenAI, model: str, messages: List[dict], ) -> AsyncGenerator[str, None]: """OpenAI流式输出""" response = await client.chat.completions.create( model=model, messages=messages, temperature=0.7, max_tokens=4096, stream=True, ) async for chunk in response: if chunk.choices[0].delta.content: yield chunk.choices[0].delta.content async def _chat_deepseek_reasoner( self, client: AsyncOpenAI, model: str, messages: List[dict], system_prompt: str, stream: bool, ) -> Union[str, AsyncGenerator[str, None]]: """DeepSeek Reasoner (思考模式) 专用调用 注意:deepseek-reasoner 不支持 temperature/top_p/system 等参数 输出包含 reasoning_content(思考过程)和 content(最终回答) """ # reasoner 不支持 system role,将 system prompt 合并到第一条用户消息 full_messages = [] for msg in messages: full_messages.append(msg) if system_prompt and full_messages: first_user = None for m in full_messages: if m["role"] == "user": first_user = m break if first_user: first_user["content"] = f"[指令] {system_prompt}\n\n[用户输入] {first_user['content']}" if stream: return self._stream_deepseek_reasoner(client, model, full_messages) else: response = await client.chat.completions.create( model=model, messages=full_messages, max_tokens=8192, ) reasoning = getattr(response.choices[0].message, 'reasoning_content', '') or '' content = response.choices[0].message.content or '' if reasoning: return f"\n{reasoning}\n\n\n{content}" return content async def _stream_deepseek_reasoner( self, client: AsyncOpenAI, model: str, messages: List[dict], ) -> AsyncGenerator[str, None]: """DeepSeek Reasoner 流式输出 - 包含思考过程和最终回答""" response = await client.chat.completions.create( model=model, messages=messages, max_tokens=8192, stream=True, ) in_reasoning = False reasoning_started = False async for chunk in response: delta = chunk.choices[0].delta # 思考过程 reasoning_content = getattr(delta, 'reasoning_content', None) if reasoning_content: if not reasoning_started: reasoning_started = True in_reasoning = True yield "
\n💭 思考过程\n\n" yield reasoning_content # 最终回答 if delta.content: if in_reasoning: in_reasoning = False yield "\n
\n\n" yield delta.content async def _chat_anthropic( self, model: str, messages: List[dict], system_prompt: str, stream: bool, ) -> Union[str, AsyncGenerator[str, None]]: """Anthropic接口调用(使用self.anthropic_client)""" return await self._chat_anthropic_with_client(self.anthropic_client, model, messages, system_prompt, stream) async def _chat_anthropic_with_client( self, client, model: str, messages: List[dict], system_prompt: str, stream: bool, ) -> Union[str, AsyncGenerator[str, None]]: """Anthropic接口调用""" if stream: return self._stream_anthropic_with_client(client, model, messages, system_prompt) else: response = await client.messages.create( model=model, max_tokens=4096, system=system_prompt if system_prompt else "You are a helpful assistant.", messages=messages, ) return response.content[0].text async def _stream_anthropic( self, model: str, messages: List[dict], system_prompt: str, ) -> AsyncGenerator[str, None]: """Anthropic流式输出(使用self.anthropic_client)""" async for text in self._stream_anthropic_with_client(self.anthropic_client, model, messages, system_prompt): yield text async def _stream_anthropic_with_client( self, client, model: str, messages: List[dict], system_prompt: str, ) -> AsyncGenerator[str, None]: """Anthropic流式输出""" async with client.messages.stream( model=model, max_tokens=4096, system=system_prompt if system_prompt else "You are a helpful assistant.", messages=messages, ) as stream: async for text in stream.text_stream: yield text # 单例 ai_service = AIService()