Files
bianchengshequ/backend/services/ai_service.py

430 lines
18 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""统一大模型调用服务 - 支持多模型路由和流式输出"""
import json
import httpx
import asyncio
from typing import AsyncGenerator, List, Optional, Union
from openai import AsyncOpenAI
from anthropic import AsyncAnthropic
from config import (
MODEL_CONFIG,
OPENAI_API_KEY,
ANTHROPIC_API_KEY,
GOOGLE_API_KEY,
DEEPSEEK_API_KEY,
ARK_API_KEY,
ARK_BASE_URL,
)
def _get_db_model_config(task_type: str, model_config_id: int = None):
"""从数据库获取指定任务类型的默认模型配置,或指定 ID 的模型"""
try:
from database import SessionLocal
from models.ai_model import AIModelConfig
db = SessionLocal()
try:
# 如果指定了模型 ID直接用该模型
if model_config_id:
model = db.query(AIModelConfig).filter(
AIModelConfig.id == model_config_id,
AIModelConfig.is_enabled == True,
).first()
if model and model.api_key:
return {
"provider": model.provider,
"model": model.model_id,
"api_key": model.api_key,
"base_url": model.base_url,
"web_search_enabled": model.web_search_enabled,
"web_search_count": model.web_search_count or 5,
}
# 否则找默认模型
model = db.query(AIModelConfig).filter(
AIModelConfig.task_type == task_type,
AIModelConfig.is_default == True,
AIModelConfig.is_enabled == True,
).first()
if model and model.api_key:
return {
"provider": model.provider,
"model": model.model_id,
"api_key": model.api_key,
"base_url": model.base_url,
"web_search_enabled": model.web_search_enabled,
"web_search_count": model.web_search_count or 5,
}
# 没有默认的找任意一个启用且有Key的
model = db.query(AIModelConfig).filter(
AIModelConfig.task_type == task_type,
AIModelConfig.is_enabled == True,
AIModelConfig.api_key != "",
).first()
if model:
return {
"provider": model.provider,
"model": model.model_id,
"api_key": model.api_key,
"base_url": model.base_url,
"web_search_enabled": model.web_search_enabled,
"web_search_count": model.web_search_count or 5,
}
finally:
db.close()
except Exception:
pass
return None
class AIService:
    """Unified AI service: routes requests to different LLM providers by task type."""

    def __init__(self):
        # OpenAI client (also the template for other OpenAI-compatible APIs).
        self.openai_client = AsyncOpenAI(api_key=OPENAI_API_KEY) if OPENAI_API_KEY else None
        # Anthropic native client.
        self.anthropic_client = AsyncAnthropic(api_key=ANTHROPIC_API_KEY) if ANTHROPIC_API_KEY else None
        # DeepSeek speaks the OpenAI-compatible protocol at its own base URL.
        if DEEPSEEK_API_KEY:
            self.deepseek_client = AsyncOpenAI(
                api_key=DEEPSEEK_API_KEY,
                base_url="https://api.deepseek.com/v1",
            )
        else:
            self.deepseek_client = None
def _get_client_for_provider(self, provider: str, api_key: str, base_url: str = ""):
"""根据provider动态创建客户端"""
if provider == "anthropic":
return AsyncAnthropic(api_key=api_key)
# openai/deepseek/google 都用OpenAI兼容接口
kwargs = {"api_key": api_key}
if base_url:
kwargs["base_url"] = base_url
return AsyncOpenAI(**kwargs)
async def chat(
self,
task_type: str,
messages: List[dict],
system_prompt: str = "",
stream: bool = False,
model_config_id: int = None,
) -> Union[str, AsyncGenerator[str, None]]:
"""
统一对话接口
参数:
task_type: 任务类型 (multimodal/reasoning/lightweight)
messages: 消息列表 [{"role": "user", "content": "..."}]
system_prompt: 系统提示词
stream: 是否流式输出
model_config_id: 指定模型配置ID可选不传则用默认
"""
# 优先从数据库读取模型配置
db_config = _get_db_model_config(task_type, model_config_id)
if db_config:
provider = db_config["provider"]
model = db_config["model"]
api_key = db_config["api_key"]
base_url = db_config["base_url"]
web_search = db_config.get("web_search_enabled", False)
web_search_count = db_config.get("web_search_count", 5)
# 火山方舟/豆包 + 联网搜索(开启后自动使用,失败则降级到普通调用)
if provider == "ark" and web_search:
try:
return await self._chat_ark_web_search(api_key, base_url, model, messages, system_prompt, stream, web_search_count)
except Exception:
# 联网搜索调用失败,降级到普通调用
pass
# 火山方舟/豆包 不带联网搜索OpenAI 兼容接口)
if provider == "ark":
kwargs = {"api_key": api_key}
if base_url:
kwargs["base_url"] = base_url
else:
kwargs["base_url"] = ARK_BASE_URL
client = AsyncOpenAI(**kwargs)
return await self._chat_openai(client, model, messages, system_prompt, stream)
if provider == "anthropic":
client = AsyncAnthropic(api_key=api_key)
return await self._chat_anthropic_with_client(client, model, messages, system_prompt, stream)
else:
kwargs = {"api_key": api_key}
if base_url:
kwargs["base_url"] = base_url
client = AsyncOpenAI(**kwargs)
# deepseek-reasoner 需要特殊处理
if model == "deepseek-reasoner":
return await self._chat_deepseek_reasoner(client, model, messages, system_prompt, stream)
return await self._chat_openai(client, model, messages, system_prompt, stream)
# 回退到 .env 配置
config = MODEL_CONFIG.get(task_type, MODEL_CONFIG["reasoning"])
provider = config["provider"]
model = config["model"]
if provider == "anthropic" and self.anthropic_client:
return await self._chat_anthropic(model, messages, system_prompt, stream)
elif provider == "openai" and self.openai_client:
return await self._chat_openai(self.openai_client, model, messages, system_prompt, stream)
elif provider == "deepseek" and self.deepseek_client:
return await self._chat_openai(self.deepseek_client, model, messages, system_prompt, stream)
elif provider == "google":
if GOOGLE_API_KEY:
google_client = AsyncOpenAI(
api_key=GOOGLE_API_KEY,
base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
)
return await self._chat_openai(google_client, model, messages, system_prompt, stream)
# 降级
if self.deepseek_client:
return await self._chat_openai(self.deepseek_client, "deepseek-chat", messages, system_prompt, stream)
if self.openai_client:
return await self._chat_openai(self.openai_client, "gpt-4o-mini", messages, system_prompt, stream)
if self.anthropic_client:
return await self._chat_anthropic("claude-sonnet-4-20250514", messages, system_prompt, stream)
return "未配置任何AI模型请到「模型管理」页面配置模型和API Key。"
async def _chat_ark_web_search(
self, api_key: str, base_url: str, model: str,
messages: List[dict], system_prompt: str, stream: bool,
search_count: int = 5,
) -> Union[str, AsyncGenerator[str, None]]:
"""火山方舟 + 联网搜索(使用 httpx 直接调用,因 web_search 是非标准 tools 类型)"""
url = f"{base_url or ARK_BASE_URL}/chat/completions"
full_messages = []
if system_prompt:
full_messages.append({"role": "system", "content": system_prompt})
full_messages.extend(messages)
# 限制搜索条数范围 1-50
search_count = max(1, min(50, search_count or 5))
payload = {
"model": model,
"messages": full_messages,
"stream": stream,
"tools": [{"type": "web_search", "web_search": {"enable": True, "search_result_count": search_count}}],
}
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
if not stream:
async with httpx.AsyncClient(timeout=120.0) as client:
resp = await client.post(url, json=payload, headers=headers)
if resp.status_code != 200:
# 联网搜索调用失败,抛出异常以便降级到普通调用
raise Exception(f"API调用失败 ({resp.status_code}): {resp.text[:200]}")
data = resp.json()
return data.get("choices", [{}])[0].get("message", {}).get("content", "")
else:
# 流式调用时,先发一个预检测请求确认模型支持 web_search
async with httpx.AsyncClient(timeout=10.0) as client:
test_payload = {
"model": model,
"messages": [{"role": "user", "content": "test"}],
"stream": False,
"max_tokens": 1,
"tools": [{"type": "web_search", "web_search": {"enable": True, "search_result_count": 1}}],
}
try:
resp = await client.post(url, json=test_payload, headers=headers)
if resp.status_code == 400:
# 模型不支持 web_search抛出异常以便降级
raise Exception("模型不支持联网搜索")
except httpx.TimeoutException:
pass # 超时不影响,继续尝试流式调用
return self._stream_ark_web_search(url, payload, headers)
async def _stream_ark_web_search(
self, url: str, payload: dict, headers: dict,
) -> AsyncGenerator[str, None]:
"""火山方舟联网搜索流式输出"""
async with httpx.AsyncClient(timeout=120.0) as client:
async with client.stream("POST", url, json=payload, headers=headers) as resp:
if resp.status_code != 200:
error_body = await resp.aread()
yield f"API调用失败 ({resp.status_code}): {error_body.decode()[:200]}"
return
buffer = ""
async for chunk in resp.aiter_text():
buffer += chunk
while "\n" in buffer:
line, buffer = buffer.split("\n", 1)
line = line.strip()
if not line or not line.startswith("data: "):
continue
data_str = line[6:]
if data_str == "[DONE]":
continue
try:
data = json.loads(data_str)
choices = data.get("choices", [])
if choices:
delta = choices[0].get("delta", {})
content = delta.get("content", "")
if content:
yield content
except json.JSONDecodeError:
pass
async def _chat_openai(
self, client: AsyncOpenAI, model: str, messages: List[dict],
system_prompt: str, stream: bool,
) -> Union[str, AsyncGenerator[str, None]]:
"""OpenAI兼容接口调用"""
full_messages = []
if system_prompt:
full_messages.append({"role": "system", "content": system_prompt})
full_messages.extend(messages)
if stream:
return self._stream_openai(client, model, full_messages)
else:
response = await client.chat.completions.create(
model=model,
messages=full_messages,
temperature=0.7,
max_tokens=4096,
)
return response.choices[0].message.content
async def _stream_openai(
self, client: AsyncOpenAI, model: str, messages: List[dict],
) -> AsyncGenerator[str, None]:
"""OpenAI流式输出"""
response = await client.chat.completions.create(
model=model,
messages=messages,
temperature=0.7,
max_tokens=4096,
stream=True,
)
async for chunk in response:
if chunk.choices[0].delta.content:
yield chunk.choices[0].delta.content
async def _chat_deepseek_reasoner(
self, client: AsyncOpenAI, model: str, messages: List[dict],
system_prompt: str, stream: bool,
) -> Union[str, AsyncGenerator[str, None]]:
"""DeepSeek Reasoner (思考模式) 专用调用
注意deepseek-reasoner 不支持 temperature/top_p/system 等参数
输出包含 reasoning_content思考过程和 content最终回答
"""
# reasoner 不支持 system role将 system prompt 合并到第一条用户消息
full_messages = []
for msg in messages:
full_messages.append(msg)
if system_prompt and full_messages:
first_user = None
for m in full_messages:
if m["role"] == "user":
first_user = m
break
if first_user:
first_user["content"] = f"[指令] {system_prompt}\n\n[用户输入] {first_user['content']}"
if stream:
return self._stream_deepseek_reasoner(client, model, full_messages)
else:
response = await client.chat.completions.create(
model=model,
messages=full_messages,
max_tokens=8192,
)
reasoning = getattr(response.choices[0].message, 'reasoning_content', '') or ''
content = response.choices[0].message.content or ''
if reasoning:
return f"<think>\n{reasoning}\n</think>\n\n{content}"
return content
async def _stream_deepseek_reasoner(
self, client: AsyncOpenAI, model: str, messages: List[dict],
) -> AsyncGenerator[str, None]:
"""DeepSeek Reasoner 流式输出 - 包含思考过程和最终回答"""
response = await client.chat.completions.create(
model=model,
messages=messages,
max_tokens=8192,
stream=True,
)
in_reasoning = False
reasoning_started = False
async for chunk in response:
delta = chunk.choices[0].delta
# 思考过程
reasoning_content = getattr(delta, 'reasoning_content', None)
if reasoning_content:
if not reasoning_started:
reasoning_started = True
in_reasoning = True
yield "<details>\n<summary>💭 思考过程</summary>\n\n"
yield reasoning_content
# 最终回答
if delta.content:
if in_reasoning:
in_reasoning = False
yield "\n</details>\n\n"
yield delta.content
async def _chat_anthropic(
self, model: str, messages: List[dict],
system_prompt: str, stream: bool,
) -> Union[str, AsyncGenerator[str, None]]:
"""Anthropic接口调用使用self.anthropic_client"""
return await self._chat_anthropic_with_client(self.anthropic_client, model, messages, system_prompt, stream)
async def _chat_anthropic_with_client(
self, client, model: str, messages: List[dict],
system_prompt: str, stream: bool,
) -> Union[str, AsyncGenerator[str, None]]:
"""Anthropic接口调用"""
if stream:
return self._stream_anthropic_with_client(client, model, messages, system_prompt)
else:
response = await client.messages.create(
model=model,
max_tokens=4096,
system=system_prompt if system_prompt else "You are a helpful assistant.",
messages=messages,
)
return response.content[0].text
async def _stream_anthropic(
self, model: str, messages: List[dict], system_prompt: str,
) -> AsyncGenerator[str, None]:
"""Anthropic流式输出使用self.anthropic_client"""
async for text in self._stream_anthropic_with_client(self.anthropic_client, model, messages, system_prompt):
yield text
async def _stream_anthropic_with_client(
self, client, model: str, messages: List[dict], system_prompt: str,
) -> AsyncGenerator[str, None]:
"""Anthropic流式输出"""
async with client.messages.stream(
model=model,
max_tokens=4096,
system=system_prompt if system_prompt else "You are a helpful assistant.",
messages=messages,
) as stream:
async for text in stream.text_stream:
yield text
# Module-level singleton shared by all importers of this module.
ai_service = AIService()