"""统一大模型调用服务 - 支持多模型路由和流式输出"""
import json
import httpx
import asyncio
from typing import AsyncGenerator, List, Optional, Union
from openai import AsyncOpenAI
from anthropic import AsyncAnthropic
from config import (
MODEL_CONFIG,
OPENAI_API_KEY,
ANTHROPIC_API_KEY,
GOOGLE_API_KEY,
DEEPSEEK_API_KEY,
ARK_API_KEY,
ARK_BASE_URL,
)
def _get_db_model_config(task_type: str, model_config_id: int = None):
"""从数据库获取指定任务类型的默认模型配置,或指定 ID 的模型"""
try:
from database import SessionLocal
from models.ai_model import AIModelConfig
db = SessionLocal()
try:
# 如果指定了模型 ID,直接用该模型
if model_config_id:
model = db.query(AIModelConfig).filter(
AIModelConfig.id == model_config_id,
AIModelConfig.is_enabled == True,
).first()
if model and model.api_key:
return {
"provider": model.provider,
"model": model.model_id,
"api_key": model.api_key,
"base_url": model.base_url,
"web_search_enabled": model.web_search_enabled,
"web_search_count": model.web_search_count or 5,
}
# 否则找默认模型
model = db.query(AIModelConfig).filter(
AIModelConfig.task_type == task_type,
AIModelConfig.is_default == True,
AIModelConfig.is_enabled == True,
).first()
if model and model.api_key:
return {
"provider": model.provider,
"model": model.model_id,
"api_key": model.api_key,
"base_url": model.base_url,
"web_search_enabled": model.web_search_enabled,
"web_search_count": model.web_search_count or 5,
}
# 没有默认的,找任意一个启用且有Key的
model = db.query(AIModelConfig).filter(
AIModelConfig.task_type == task_type,
AIModelConfig.is_enabled == True,
AIModelConfig.api_key != "",
).first()
if model:
return {
"provider": model.provider,
"model": model.model_id,
"api_key": model.api_key,
"base_url": model.base_url,
"web_search_enabled": model.web_search_enabled,
"web_search_count": model.web_search_count or 5,
}
finally:
db.close()
except Exception:
pass
return None
class AIService:
    """Unified AI service that routes chat requests to a provider by task type.

    Model selection order: database-stored config (via ``_get_db_model_config``),
    then the .env ``MODEL_CONFIG`` mapping, then a hard-coded fallback chain.
    """

    def __init__(self):
        # OpenAI client (the same SDK also serves OpenAI-compatible APIs such as DeepSeek)
        if OPENAI_API_KEY:
            self.openai_client = AsyncOpenAI(api_key=OPENAI_API_KEY)
        else:
            self.openai_client = None
        # Anthropic client
        if ANTHROPIC_API_KEY:
            self.anthropic_client = AsyncAnthropic(api_key=ANTHROPIC_API_KEY)
        else:
            self.anthropic_client = None
        # DeepSeek client (OpenAI-compatible endpoint)
        if DEEPSEEK_API_KEY:
            self.deepseek_client = AsyncOpenAI(
                api_key=DEEPSEEK_API_KEY,
                base_url="https://api.deepseek.com/v1",
            )
        else:
            self.deepseek_client = None

    def _get_client_for_provider(self, provider: str, api_key: str, base_url: str = ""):
        """Build a fresh SDK client for *provider*.

        openai / deepseek / google all speak the OpenAI-compatible protocol,
        so everything except "anthropic" gets an AsyncOpenAI client.
        """
        if provider == "anthropic":
            return AsyncAnthropic(api_key=api_key)
        kwargs = {"api_key": api_key}
        if base_url:
            kwargs["base_url"] = base_url
        return AsyncOpenAI(**kwargs)

    async def chat(
        self,
        task_type: str,
        messages: List[dict],
        system_prompt: str = "",
        stream: bool = False,
        model_config_id: Optional[int] = None,
    ) -> Union[str, AsyncGenerator[str, None]]:
        """Unified chat entry point.

        Args:
            task_type: task category (multimodal/reasoning/lightweight).
            messages: chat history, e.g. ``[{"role": "user", "content": "..."}]``.
            system_prompt: optional system prompt.
            stream: when True the return value is an async generator of text chunks.
            model_config_id: explicit DB model-config id (optional; default model
                for the task type is used when omitted).

        Returns:
            The full reply string, or an async generator when ``stream=True``.
        """
        # Prefer a model configuration stored in the database.
        db_config = _get_db_model_config(task_type, model_config_id)
        if db_config:
            provider = db_config["provider"]
            model = db_config["model"]
            api_key = db_config["api_key"]
            base_url = db_config["base_url"]
            web_search = db_config.get("web_search_enabled", False)
            web_search_count = db_config.get("web_search_count", 5)
            # Volcengine Ark / Doubao with web search enabled; on any failure
            # silently degrade to the plain (non-search) call below.
            if provider == "ark" and web_search:
                try:
                    return await self._chat_ark_web_search(api_key, base_url, model, messages, system_prompt, stream, web_search_count)
                except Exception:
                    pass
            # Volcengine Ark / Doubao without web search (OpenAI-compatible API).
            if provider == "ark":
                kwargs = {"api_key": api_key}
                # Ark requires its own base URL even when the row left it blank.
                kwargs["base_url"] = base_url if base_url else ARK_BASE_URL
                client = AsyncOpenAI(**kwargs)
                return await self._chat_openai(client, model, messages, system_prompt, stream)
            if provider == "anthropic":
                client = AsyncAnthropic(api_key=api_key)
                return await self._chat_anthropic_with_client(client, model, messages, system_prompt, stream)
            else:
                kwargs = {"api_key": api_key}
                if base_url:
                    kwargs["base_url"] = base_url
                client = AsyncOpenAI(**kwargs)
                # deepseek-reasoner rejects system role / temperature — handle specially.
                if model == "deepseek-reasoner":
                    return await self._chat_deepseek_reasoner(client, model, messages, system_prompt, stream)
                return await self._chat_openai(client, model, messages, system_prompt, stream)
        # Fall back to the .env MODEL_CONFIG mapping.
        config = MODEL_CONFIG.get(task_type, MODEL_CONFIG["reasoning"])
        provider = config["provider"]
        model = config["model"]
        if provider == "anthropic" and self.anthropic_client:
            return await self._chat_anthropic(model, messages, system_prompt, stream)
        elif provider == "openai" and self.openai_client:
            return await self._chat_openai(self.openai_client, model, messages, system_prompt, stream)
        elif provider == "deepseek" and self.deepseek_client:
            return await self._chat_openai(self.deepseek_client, model, messages, system_prompt, stream)
        elif provider == "google":
            if GOOGLE_API_KEY:
                google_client = AsyncOpenAI(
                    api_key=GOOGLE_API_KEY,
                    base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
                )
                return await self._chat_openai(google_client, model, messages, system_prompt, stream)
        # Last-resort fallback chain: any client that happens to be configured.
        if self.deepseek_client:
            return await self._chat_openai(self.deepseek_client, "deepseek-chat", messages, system_prompt, stream)
        if self.openai_client:
            return await self._chat_openai(self.openai_client, "gpt-4o-mini", messages, system_prompt, stream)
        if self.anthropic_client:
            return await self._chat_anthropic("claude-sonnet-4-20250514", messages, system_prompt, stream)
        return "未配置任何AI模型,请到「模型管理」页面配置模型和API Key。"

    async def _chat_ark_web_search(
        self, api_key: str, base_url: str, model: str,
        messages: List[dict], system_prompt: str, stream: bool,
        search_count: int = 5,
    ) -> Union[str, AsyncGenerator[str, None]]:
        """Volcengine Ark chat with web search.

        Uses httpx directly because ``web_search`` is a non-standard ``tools``
        entry the OpenAI SDK would reject.

        Raises:
            Exception: on a non-200 response (non-stream) or when the
                pre-flight check shows the model does not support web search —
                the caller catches this and degrades to a plain call.
        """
        url = f"{base_url or ARK_BASE_URL}/chat/completions"
        full_messages = []
        if system_prompt:
            full_messages.append({"role": "system", "content": system_prompt})
        full_messages.extend(messages)
        # Clamp the result count to the API's accepted 1-50 range.
        search_count = max(1, min(50, search_count or 5))
        payload = {
            "model": model,
            "messages": full_messages,
            "stream": stream,
            "tools": [{"type": "web_search", "web_search": {"enable": True, "search_result_count": search_count}}],
        }
        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }
        if not stream:
            async with httpx.AsyncClient(timeout=120.0) as client:
                resp = await client.post(url, json=payload, headers=headers)
                if resp.status_code != 200:
                    # Raise so chat() can degrade to a plain (non-search) call.
                    raise Exception(f"API调用失败 ({resp.status_code}): {resp.text[:200]}")
                data = resp.json()
                return data.get("choices", [{}])[0].get("message", {}).get("content", "")
        else:
            # For streaming, first send a cheap pre-flight request to confirm
            # the model supports web_search — once we return the generator,
            # errors inside it can no longer trigger the caller's degradation.
            async with httpx.AsyncClient(timeout=10.0) as client:
                test_payload = {
                    "model": model,
                    "messages": [{"role": "user", "content": "test"}],
                    "stream": False,
                    "max_tokens": 1,
                    "tools": [{"type": "web_search", "web_search": {"enable": True, "search_result_count": 1}}],
                }
                try:
                    resp = await client.post(url, json=test_payload, headers=headers)
                    if resp.status_code == 400:
                        # Model rejects web_search: raise so the caller degrades.
                        raise Exception("模型不支持联网搜索")
                except httpx.TimeoutException:
                    pass  # Timeout is inconclusive; proceed with the stream attempt.
            return self._stream_ark_web_search(url, payload, headers)

    async def _stream_ark_web_search(
        self, url: str, payload: dict, headers: dict,
    ) -> AsyncGenerator[str, None]:
        """Stream Ark web-search output, parsing SSE ``data:`` lines by hand."""
        async with httpx.AsyncClient(timeout=120.0) as client:
            async with client.stream("POST", url, json=payload, headers=headers) as resp:
                if resp.status_code != 200:
                    error_body = await resp.aread()
                    yield f"API调用失败 ({resp.status_code}): {error_body.decode()[:200]}"
                    return
                buffer = ""
                async for chunk in resp.aiter_text():
                    buffer += chunk
                    # Process complete lines; keep any partial line in the buffer.
                    while "\n" in buffer:
                        line, buffer = buffer.split("\n", 1)
                        line = line.strip()
                        if not line or not line.startswith("data: "):
                            continue
                        data_str = line[6:]
                        if data_str == "[DONE]":
                            continue
                        try:
                            data = json.loads(data_str)
                            choices = data.get("choices", [])
                            if choices:
                                delta = choices[0].get("delta", {})
                                content = delta.get("content", "")
                                if content:
                                    yield content
                        except json.JSONDecodeError:
                            pass  # Ignore malformed/partial SSE payloads.

    async def _chat_openai(
        self, client: AsyncOpenAI, model: str, messages: List[dict],
        system_prompt: str, stream: bool,
    ) -> Union[str, AsyncGenerator[str, None]]:
        """Call an OpenAI-compatible chat endpoint (non-stream or stream)."""
        full_messages = []
        if system_prompt:
            full_messages.append({"role": "system", "content": system_prompt})
        full_messages.extend(messages)
        if stream:
            return self._stream_openai(client, model, full_messages)
        else:
            response = await client.chat.completions.create(
                model=model,
                messages=full_messages,
                temperature=0.7,
                max_tokens=4096,
            )
            return response.choices[0].message.content

    async def _stream_openai(
        self, client: AsyncOpenAI, model: str, messages: List[dict],
    ) -> AsyncGenerator[str, None]:
        """Yield content deltas from an OpenAI-compatible streaming response."""
        response = await client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=0.7,
            max_tokens=4096,
            stream=True,
        )
        async for chunk in response:
            if chunk.choices[0].delta.content:
                yield chunk.choices[0].delta.content

    async def _chat_deepseek_reasoner(
        self, client: AsyncOpenAI, model: str, messages: List[dict],
        system_prompt: str, stream: bool,
    ) -> Union[str, AsyncGenerator[str, None]]:
        """DeepSeek Reasoner (thinking mode) call.

        Note: deepseek-reasoner does not accept temperature/top_p/system
        parameters; its output carries ``reasoning_content`` (the thinking
        trace) alongside ``content`` (the final answer).
        """
        # The reasoner rejects the system role, so merge the system prompt
        # into the first user message. Copy that message dict instead of
        # mutating it — the original code rewrote the caller's dict in place,
        # corrupting the shared conversation history.
        full_messages = list(messages)
        if system_prompt:
            for i, msg in enumerate(full_messages):
                if msg["role"] == "user":
                    merged = dict(msg)
                    merged["content"] = f"[指令] {system_prompt}\n\n[用户输入] {msg['content']}"
                    full_messages[i] = merged
                    break
        if stream:
            return self._stream_deepseek_reasoner(client, model, full_messages)
        else:
            response = await client.chat.completions.create(
                model=model,
                messages=full_messages,
                max_tokens=8192,
            )
            reasoning = getattr(response.choices[0].message, 'reasoning_content', '') or ''
            content = response.choices[0].message.content or ''
            if reasoning:
                return f"\n{reasoning}\n\n\n{content}"
            return content

    async def _stream_deepseek_reasoner(
        self, client: AsyncOpenAI, model: str, messages: List[dict],
    ) -> AsyncGenerator[str, None]:
        """Stream DeepSeek Reasoner output: thinking trace first, then the answer."""
        response = await client.chat.completions.create(
            model=model,
            messages=messages,
            max_tokens=8192,
            stream=True,
        )
        in_reasoning = False
        reasoning_started = False
        async for chunk in response:
            delta = chunk.choices[0].delta
            # Thinking trace (reasoning_content is absent on non-reasoner deltas).
            reasoning_content = getattr(delta, 'reasoning_content', None)
            if reasoning_content:
                if not reasoning_started:
                    reasoning_started = True
                    in_reasoning = True
                    # Header emitted once before the thinking trace. (Fixed:
                    # the original literal was split across two source lines,
                    # which is a Python syntax error.)
                    yield "\n💭 思考过程\n\n"
                yield reasoning_content
            # Final answer; emit a separator the first time we leave the trace.
            if delta.content:
                if in_reasoning:
                    in_reasoning = False
                    yield "\n \n\n"
                yield delta.content

    async def _chat_anthropic(
        self, model: str, messages: List[dict],
        system_prompt: str, stream: bool,
    ) -> Union[str, AsyncGenerator[str, None]]:
        """Anthropic call using the shared ``self.anthropic_client``."""
        return await self._chat_anthropic_with_client(self.anthropic_client, model, messages, system_prompt, stream)

    async def _chat_anthropic_with_client(
        self, client, model: str, messages: List[dict],
        system_prompt: str, stream: bool,
    ) -> Union[str, AsyncGenerator[str, None]]:
        """Anthropic call with an explicitly supplied client."""
        if stream:
            return self._stream_anthropic_with_client(client, model, messages, system_prompt)
        else:
            response = await client.messages.create(
                model=model,
                max_tokens=4096,
                system=system_prompt if system_prompt else "You are a helpful assistant.",
                messages=messages,
            )
            return response.content[0].text

    async def _stream_anthropic(
        self, model: str, messages: List[dict], system_prompt: str,
    ) -> AsyncGenerator[str, None]:
        """Anthropic streaming using the shared ``self.anthropic_client``."""
        async for text in self._stream_anthropic_with_client(self.anthropic_client, model, messages, system_prompt):
            yield text

    async def _stream_anthropic_with_client(
        self, client, model: str, messages: List[dict], system_prompt: str,
    ) -> AsyncGenerator[str, None]:
        """Anthropic streaming with an explicitly supplied client."""
        async with client.messages.stream(
            model=model,
            max_tokens=4096,
            system=system_prompt if system_prompt else "You are a helpful assistant.",
            messages=messages,
        ) as stream:
            async for text in stream.text_stream:
                yield text
# Module-level singleton shared by the whole application.
ai_service = AIService()