"""Unified LLM invocation service - supports multi-model routing and streaming output."""
import json
import httpx
import asyncio
from typing import AsyncGenerator, List, Optional, Union
from openai import AsyncOpenAI
from anthropic import AsyncAnthropic

from config import (
    MODEL_CONFIG,
    OPENAI_API_KEY,
    ANTHROPIC_API_KEY,
    GOOGLE_API_KEY,
    DEEPSEEK_API_KEY,
    ARK_API_KEY,
    ARK_BASE_URL,
)
def _get_db_model_config(task_type: str, model_config_id: int = None):
|
||
"""从数据库获取指定任务类型的默认模型配置,或指定 ID 的模型"""
|
||
try:
|
||
from database import SessionLocal
|
||
from models.ai_model import AIModelConfig
|
||
db = SessionLocal()
|
||
try:
|
||
# 如果指定了模型 ID,直接用该模型
|
||
if model_config_id:
|
||
model = db.query(AIModelConfig).filter(
|
||
AIModelConfig.id == model_config_id,
|
||
AIModelConfig.is_enabled == True,
|
||
).first()
|
||
if model and model.api_key:
|
||
return {
|
||
"provider": model.provider,
|
||
"model": model.model_id,
|
||
"api_key": model.api_key,
|
||
"base_url": model.base_url,
|
||
"web_search_enabled": model.web_search_enabled,
|
||
"web_search_count": model.web_search_count or 5,
|
||
}
|
||
# 否则找默认模型
|
||
model = db.query(AIModelConfig).filter(
|
||
AIModelConfig.task_type == task_type,
|
||
AIModelConfig.is_default == True,
|
||
AIModelConfig.is_enabled == True,
|
||
).first()
|
||
if model and model.api_key:
|
||
return {
|
||
"provider": model.provider,
|
||
"model": model.model_id,
|
||
"api_key": model.api_key,
|
||
"base_url": model.base_url,
|
||
"web_search_enabled": model.web_search_enabled,
|
||
"web_search_count": model.web_search_count or 5,
|
||
}
|
||
# 没有默认的,找任意一个启用且有Key的
|
||
model = db.query(AIModelConfig).filter(
|
||
AIModelConfig.task_type == task_type,
|
||
AIModelConfig.is_enabled == True,
|
||
AIModelConfig.api_key != "",
|
||
).first()
|
||
if model:
|
||
return {
|
||
"provider": model.provider,
|
||
"model": model.model_id,
|
||
"api_key": model.api_key,
|
||
"base_url": model.base_url,
|
||
"web_search_enabled": model.web_search_enabled,
|
||
"web_search_count": model.web_search_count or 5,
|
||
}
|
||
finally:
|
||
db.close()
|
||
except Exception:
|
||
pass
|
||
return None
|
||
|
||
|
||
class AIService:
    """Unified AI service that routes chat requests to different LLM providers.

    Configuration is resolved per request in this order:
      1. database-managed model configs (``_get_db_model_config``);
      2. ``.env``-based ``MODEL_CONFIG`` using the clients built in ``__init__``;
      3. a hard-coded fallback chain: DeepSeek -> OpenAI -> Anthropic.
    """

    def __init__(self):
        """Create long-lived clients for providers configured via environment keys."""
        # OpenAI client (also serves OpenAI-compatible APIs such as DeepSeek).
        self.openai_client = AsyncOpenAI(api_key=OPENAI_API_KEY) if OPENAI_API_KEY else None

        # Anthropic client.
        self.anthropic_client = AsyncAnthropic(api_key=ANTHROPIC_API_KEY) if ANTHROPIC_API_KEY else None

        # DeepSeek client (OpenAI-compatible endpoint).
        if DEEPSEEK_API_KEY:
            self.deepseek_client = AsyncOpenAI(
                api_key=DEEPSEEK_API_KEY,
                base_url="https://api.deepseek.com/v1",
            )
        else:
            self.deepseek_client = None

    def _get_client_for_provider(self, provider: str, api_key: str, base_url: str = ""):
        """Build a client for *provider*: a native Anthropic client, or an
        OpenAI-compatible client for everything else."""
        if provider == "anthropic":
            return AsyncAnthropic(api_key=api_key)
        # openai / deepseek / google all speak the OpenAI-compatible protocol.
        kwargs = {"api_key": api_key}
        if base_url:
            kwargs["base_url"] = base_url
        return AsyncOpenAI(**kwargs)

    async def chat(
        self,
        task_type: str,
        messages: List[dict],
        system_prompt: str = "",
        stream: bool = False,
        model_config_id: Optional[int] = None,
    ) -> Union[str, AsyncGenerator[str, None]]:
        """Unified chat entry point.

        Args:
            task_type: task category (multimodal/reasoning/lightweight).
            messages: message list, e.g. [{"role": "user", "content": "..."}].
            system_prompt: optional system prompt.
            stream: when True, return an async generator of text chunks;
                otherwise return the full response string.
            model_config_id: explicit model config ID (optional; the default
                model for the task type is used when omitted).
        """
        # Prefer the database-managed model configuration.
        db_config = _get_db_model_config(task_type, model_config_id)
        if db_config:
            provider = db_config["provider"]
            model = db_config["model"]
            api_key = db_config["api_key"]
            base_url = db_config["base_url"]
            web_search = db_config.get("web_search_enabled", False)
            web_search_count = db_config.get("web_search_count", 5)

            # Volcano Ark / Doubao with web search enabled; on any failure we
            # degrade to a plain (non-search) call below.
            if provider == "ark" and web_search:
                try:
                    return await self._chat_ark_web_search(
                        api_key, base_url, model, messages, system_prompt, stream, web_search_count,
                    )
                except Exception:
                    pass  # web-search call failed; fall through to plain call

            # Volcano Ark / Doubao without web search (OpenAI-compatible API).
            if provider == "ark":
                client = AsyncOpenAI(api_key=api_key, base_url=base_url or ARK_BASE_URL)
                return await self._chat_openai(client, model, messages, system_prompt, stream)

            if provider == "anthropic":
                client = AsyncAnthropic(api_key=api_key)
                return await self._chat_anthropic_with_client(client, model, messages, system_prompt, stream)

            # Everything else goes through the OpenAI-compatible protocol.
            kwargs = {"api_key": api_key}
            if base_url:
                kwargs["base_url"] = base_url
            client = AsyncOpenAI(**kwargs)
            # deepseek-reasoner needs special parameter/message handling.
            if model == "deepseek-reasoner":
                return await self._chat_deepseek_reasoner(client, model, messages, system_prompt, stream)
            return await self._chat_openai(client, model, messages, system_prompt, stream)

        # Fall back to the .env-based configuration.
        config = MODEL_CONFIG.get(task_type, MODEL_CONFIG["reasoning"])
        provider = config["provider"]
        model = config["model"]

        if provider == "anthropic" and self.anthropic_client:
            return await self._chat_anthropic(model, messages, system_prompt, stream)
        elif provider == "openai" and self.openai_client:
            return await self._chat_openai(self.openai_client, model, messages, system_prompt, stream)
        elif provider == "deepseek" and self.deepseek_client:
            return await self._chat_openai(self.deepseek_client, model, messages, system_prompt, stream)
        elif provider == "google":
            if GOOGLE_API_KEY:
                google_client = AsyncOpenAI(
                    api_key=GOOGLE_API_KEY,
                    base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
                )
                return await self._chat_openai(google_client, model, messages, system_prompt, stream)

        # Last-resort fallback chain.
        if self.deepseek_client:
            return await self._chat_openai(self.deepseek_client, "deepseek-chat", messages, system_prompt, stream)
        if self.openai_client:
            return await self._chat_openai(self.openai_client, "gpt-4o-mini", messages, system_prompt, stream)
        if self.anthropic_client:
            return await self._chat_anthropic("claude-sonnet-4-20250514", messages, system_prompt, stream)

        return "未配置任何AI模型,请到「模型管理」页面配置模型和API Key。"

    async def _chat_ark_web_search(
        self, api_key: str, base_url: str, model: str,
        messages: List[dict], system_prompt: str, stream: bool,
        search_count: int = 5,
    ) -> Union[str, AsyncGenerator[str, None]]:
        """Volcano Ark chat with web search enabled.

        Uses httpx directly because ``web_search`` is a non-standard ``tools``
        type. Raises on failure so chat() can degrade to a plain call.
        """
        url = f"{base_url or ARK_BASE_URL}/chat/completions"
        full_messages = []
        if system_prompt:
            full_messages.append({"role": "system", "content": system_prompt})
        full_messages.extend(messages)

        # Clamp the search result count to the accepted 1-50 range.
        search_count = max(1, min(50, search_count or 5))
        payload = {
            "model": model,
            "messages": full_messages,
            "stream": stream,
            "tools": [{"type": "web_search", "web_search": {"enable": True, "search_result_count": search_count}}],
        }
        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }

        if not stream:
            async with httpx.AsyncClient(timeout=120.0) as client:
                resp = await client.post(url, json=payload, headers=headers)
                if resp.status_code != 200:
                    # Raise so the caller can fall back to a non-search call.
                    raise Exception(f"API调用失败 ({resp.status_code}): {resp.text[:200]}")
                data = resp.json()
                return data.get("choices", [{}])[0].get("message", {}).get("content", "")
        else:
            # For streaming, send a cheap non-stream probe first to confirm
            # the model accepts the web_search tool before committing.
            async with httpx.AsyncClient(timeout=10.0) as client:
                test_payload = {
                    "model": model,
                    "messages": [{"role": "user", "content": "test"}],
                    "stream": False,
                    "max_tokens": 1,
                    "tools": [{"type": "web_search", "web_search": {"enable": True, "search_result_count": 1}}],
                }
                try:
                    resp = await client.post(url, json=test_payload, headers=headers)
                    if resp.status_code == 400:
                        # Model rejects web_search; raise so chat() degrades.
                        raise Exception("模型不支持联网搜索")
                except httpx.TimeoutException:
                    pass  # probe timeout is non-fatal; still attempt the stream
            return self._stream_ark_web_search(url, payload, headers)

    async def _stream_ark_web_search(
        self, url: str, payload: dict, headers: dict,
    ) -> AsyncGenerator[str, None]:
        """Stream SSE chunks from an Ark web-search call, yielding text deltas."""
        async with httpx.AsyncClient(timeout=120.0) as client:
            async with client.stream("POST", url, json=payload, headers=headers) as resp:
                if resp.status_code != 200:
                    error_body = await resp.aread()
                    yield f"API调用失败 ({resp.status_code}): {error_body.decode()[:200]}"
                    return
                buffer = ""
                async for chunk in resp.aiter_text():
                    buffer += chunk
                    # Process complete SSE lines; keep the partial tail buffered.
                    while "\n" in buffer:
                        line, buffer = buffer.split("\n", 1)
                        line = line.strip()
                        if not line or not line.startswith("data: "):
                            continue
                        data_str = line[6:]
                        if data_str == "[DONE]":
                            continue
                        try:
                            data = json.loads(data_str)
                            choices = data.get("choices", [])
                            if choices:
                                delta = choices[0].get("delta", {})
                                content = delta.get("content", "")
                                if content:
                                    yield content
                        except json.JSONDecodeError:
                            pass  # tolerate malformed keep-alive / partial lines

    async def _chat_openai(
        self, client: AsyncOpenAI, model: str, messages: List[dict],
        system_prompt: str, stream: bool,
    ) -> Union[str, AsyncGenerator[str, None]]:
        """Call an OpenAI-compatible chat completions endpoint."""
        full_messages = []
        if system_prompt:
            full_messages.append({"role": "system", "content": system_prompt})
        full_messages.extend(messages)

        if stream:
            return self._stream_openai(client, model, full_messages)
        response = await client.chat.completions.create(
            model=model,
            messages=full_messages,
            temperature=0.7,
            max_tokens=4096,
        )
        # content may be None (e.g. tool-only responses); normalize to str.
        return response.choices[0].message.content or ""

    async def _stream_openai(
        self, client: AsyncOpenAI, model: str, messages: List[dict],
    ) -> AsyncGenerator[str, None]:
        """Stream text deltas from an OpenAI-compatible endpoint."""
        response = await client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=0.7,
            max_tokens=4096,
            stream=True,
        )
        async for chunk in response:
            if chunk.choices[0].delta.content:
                yield chunk.choices[0].delta.content

    async def _chat_deepseek_reasoner(
        self, client: AsyncOpenAI, model: str, messages: List[dict],
        system_prompt: str, stream: bool,
    ) -> Union[str, AsyncGenerator[str, None]]:
        """Call DeepSeek Reasoner (thinking mode).

        Notes:
            - deepseek-reasoner does not support temperature/top_p/system params.
            - The response carries ``reasoning_content`` (chain of thought) in
              addition to ``content`` (the final answer).
        """
        # The reasoner has no system role: merge the system prompt into the
        # first user message. Copy each message dict so the caller's list is
        # never mutated (the previous implementation rewrote it in place).
        full_messages = [dict(msg) for msg in messages]
        if system_prompt:
            for m in full_messages:
                if m["role"] == "user":
                    m["content"] = f"[指令] {system_prompt}\n\n[用户输入] {m['content']}"
                    break

        if stream:
            return self._stream_deepseek_reasoner(client, model, full_messages)
        response = await client.chat.completions.create(
            model=model,
            messages=full_messages,
            max_tokens=8192,
        )
        reasoning = getattr(response.choices[0].message, 'reasoning_content', '') or ''
        content = response.choices[0].message.content or ''
        if reasoning:
            return f"<think>\n{reasoning}\n</think>\n\n{content}"
        return content

    async def _stream_deepseek_reasoner(
        self, client: AsyncOpenAI, model: str, messages: List[dict],
    ) -> AsyncGenerator[str, None]:
        """Stream DeepSeek Reasoner output: thinking wrapped in <details>, then the answer."""
        response = await client.chat.completions.create(
            model=model,
            messages=messages,
            max_tokens=8192,
            stream=True,
        )
        in_reasoning = False
        reasoning_started = False
        async for chunk in response:
            delta = chunk.choices[0].delta
            # Thinking phase: open the <details> block on the first delta.
            reasoning_content = getattr(delta, 'reasoning_content', None)
            if reasoning_content:
                if not reasoning_started:
                    reasoning_started = True
                    in_reasoning = True
                    yield "<details>\n<summary>💭 思考过程</summary>\n\n"
                yield reasoning_content
            # Final answer: close the <details> block on the first content delta.
            if delta.content:
                if in_reasoning:
                    in_reasoning = False
                    yield "\n</details>\n\n"
                yield delta.content

    async def _chat_anthropic(
        self, model: str, messages: List[dict],
        system_prompt: str, stream: bool,
    ) -> Union[str, AsyncGenerator[str, None]]:
        """Anthropic call using the shared ``self.anthropic_client``."""
        return await self._chat_anthropic_with_client(self.anthropic_client, model, messages, system_prompt, stream)

    async def _chat_anthropic_with_client(
        self, client, model: str, messages: List[dict],
        system_prompt: str, stream: bool,
    ) -> Union[str, AsyncGenerator[str, None]]:
        """Anthropic Messages API call with an explicit client."""
        if stream:
            return self._stream_anthropic_with_client(client, model, messages, system_prompt)
        response = await client.messages.create(
            model=model,
            max_tokens=4096,
            system=system_prompt if system_prompt else "You are a helpful assistant.",
            messages=messages,
        )
        return response.content[0].text

    async def _stream_anthropic(
        self, model: str, messages: List[dict], system_prompt: str,
    ) -> AsyncGenerator[str, None]:
        """Anthropic streaming using the shared ``self.anthropic_client``."""
        async for text in self._stream_anthropic_with_client(self.anthropic_client, model, messages, system_prompt):
            yield text

    async def _stream_anthropic_with_client(
        self, client, model: str, messages: List[dict], system_prompt: str,
    ) -> AsyncGenerator[str, None]:
        """Stream text deltas from the Anthropic Messages API."""
        async with client.messages.stream(
            model=model,
            max_tokens=4096,
            system=system_prompt if system_prompt else "You are a helpful assistant.",
            messages=messages,
        ) as stream:
            async for text in stream.text_stream:
                yield text
# Module-level singleton shared by importers of this module.
ai_service = AIService()