初始提交:极码 GeekCode 全栈项目(FastAPI + Vue3)
This commit is contained in:
0
backend/services/__init__.py
Normal file
0
backend/services/__init__.py
Normal file
429
backend/services/ai_service.py
Normal file
429
backend/services/ai_service.py
Normal file
@@ -0,0 +1,429 @@
|
||||
"""统一大模型调用服务 - 支持多模型路由和流式输出"""
|
||||
import json
|
||||
import httpx
|
||||
import asyncio
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
from openai import AsyncOpenAI
|
||||
from anthropic import AsyncAnthropic
|
||||
|
||||
from config import (
|
||||
MODEL_CONFIG,
|
||||
OPENAI_API_KEY,
|
||||
ANTHROPIC_API_KEY,
|
||||
GOOGLE_API_KEY,
|
||||
DEEPSEEK_API_KEY,
|
||||
ARK_API_KEY,
|
||||
ARK_BASE_URL,
|
||||
)
|
||||
|
||||
|
||||
def _get_db_model_config(task_type: str, model_config_id: int = None):
|
||||
"""从数据库获取指定任务类型的默认模型配置,或指定 ID 的模型"""
|
||||
try:
|
||||
from database import SessionLocal
|
||||
from models.ai_model import AIModelConfig
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# 如果指定了模型 ID,直接用该模型
|
||||
if model_config_id:
|
||||
model = db.query(AIModelConfig).filter(
|
||||
AIModelConfig.id == model_config_id,
|
||||
AIModelConfig.is_enabled == True,
|
||||
).first()
|
||||
if model and model.api_key:
|
||||
return {
|
||||
"provider": model.provider,
|
||||
"model": model.model_id,
|
||||
"api_key": model.api_key,
|
||||
"base_url": model.base_url,
|
||||
"web_search_enabled": model.web_search_enabled,
|
||||
"web_search_count": model.web_search_count or 5,
|
||||
}
|
||||
# 否则找默认模型
|
||||
model = db.query(AIModelConfig).filter(
|
||||
AIModelConfig.task_type == task_type,
|
||||
AIModelConfig.is_default == True,
|
||||
AIModelConfig.is_enabled == True,
|
||||
).first()
|
||||
if model and model.api_key:
|
||||
return {
|
||||
"provider": model.provider,
|
||||
"model": model.model_id,
|
||||
"api_key": model.api_key,
|
||||
"base_url": model.base_url,
|
||||
"web_search_enabled": model.web_search_enabled,
|
||||
"web_search_count": model.web_search_count or 5,
|
||||
}
|
||||
# 没有默认的,找任意一个启用且有Key的
|
||||
model = db.query(AIModelConfig).filter(
|
||||
AIModelConfig.task_type == task_type,
|
||||
AIModelConfig.is_enabled == True,
|
||||
AIModelConfig.api_key != "",
|
||||
).first()
|
||||
if model:
|
||||
return {
|
||||
"provider": model.provider,
|
||||
"model": model.model_id,
|
||||
"api_key": model.api_key,
|
||||
"base_url": model.base_url,
|
||||
"web_search_enabled": model.web_search_enabled,
|
||||
"web_search_count": model.web_search_count or 5,
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
class AIService:
    """Unified AI service that routes chat requests to a provider by task type.

    Configuration is resolved in two tiers: per-task records from the database
    (``_get_db_model_config``) take precedence, then the static ``MODEL_CONFIG``
    loaded from .env, then a last-resort cascade over whichever clients were
    constructed at startup. Chinese strings below are user-facing runtime
    output and are intentionally left untouched.
    """

    def __init__(self):
        # OpenAI client (the same SDK is reused for DeepSeek and other
        # OpenAI-compatible endpoints elsewhere in this class).
        if OPENAI_API_KEY:
            self.openai_client = AsyncOpenAI(api_key=OPENAI_API_KEY)
        else:
            self.openai_client = None

        # Anthropic client.
        if ANTHROPIC_API_KEY:
            self.anthropic_client = AsyncAnthropic(api_key=ANTHROPIC_API_KEY)
        else:
            self.anthropic_client = None

        # DeepSeek client (OpenAI-compatible API surface).
        if DEEPSEEK_API_KEY:
            self.deepseek_client = AsyncOpenAI(
                api_key=DEEPSEEK_API_KEY,
                base_url="https://api.deepseek.com/v1",
            )
        else:
            self.deepseek_client = None

    def _get_client_for_provider(self, provider: str, api_key: str, base_url: str = ""):
        """Build a fresh SDK client for *provider*.

        "anthropic" gets the native Anthropic client; every other provider
        (openai/deepseek/google/...) is served through the OpenAI-compatible
        client, optionally pointed at *base_url*.
        NOTE(review): no call site is visible within this file — presumably
        used by other services; confirm before removing.
        """
        if provider == "anthropic":
            return AsyncAnthropic(api_key=api_key)
        # openai/deepseek/google all speak the OpenAI-compatible protocol.
        kwargs = {"api_key": api_key}
        if base_url:
            kwargs["base_url"] = base_url
        return AsyncOpenAI(**kwargs)

    async def chat(
        self,
        task_type: str,
        messages: List[dict],
        system_prompt: str = "",
        stream: bool = False,
        model_config_id: Optional[int] = None,
    ) -> Union[str, AsyncGenerator[str, None]]:
        """Unified chat entry point.

        Args:
            task_type: task category (multimodal/reasoning/lightweight).
            messages: message list, e.g. [{"role": "user", "content": "..."}].
            system_prompt: optional system prompt.
            stream: when True, the return value is an async generator of
                text chunks instead of a complete string.
            model_config_id: explicit model-config ID (optional; the task's
                default model is used when omitted).

        Returns:
            The full response text, or an async text-chunk generator when
            *stream* is True.
        """
        # Prefer the model configuration stored in the database.
        db_config = _get_db_model_config(task_type, model_config_id)
        if db_config:
            provider = db_config["provider"]
            model = db_config["model"]
            api_key = db_config["api_key"]
            base_url = db_config["base_url"]
            web_search = db_config.get("web_search_enabled", False)
            web_search_count = db_config.get("web_search_count", 5)
            # Volcengine Ark / Doubao with web search enabled: try the
            # search-augmented call first; on any failure fall through to a
            # plain (non-search) call below.
            if provider == "ark" and web_search:
                try:
                    return await self._chat_ark_web_search(api_key, base_url, model, messages, system_prompt, stream, web_search_count)
                except Exception:
                    # Web-search call failed — degrade to a plain call.
                    pass
            # Volcengine Ark / Doubao without web search (OpenAI-compatible).
            if provider == "ark":
                kwargs = {"api_key": api_key}
                if base_url:
                    kwargs["base_url"] = base_url
                else:
                    kwargs["base_url"] = ARK_BASE_URL
                client = AsyncOpenAI(**kwargs)
                return await self._chat_openai(client, model, messages, system_prompt, stream)
            if provider == "anthropic":
                client = AsyncAnthropic(api_key=api_key)
                return await self._chat_anthropic_with_client(client, model, messages, system_prompt, stream)
            else:
                kwargs = {"api_key": api_key}
                if base_url:
                    kwargs["base_url"] = base_url
                client = AsyncOpenAI(**kwargs)
                # deepseek-reasoner needs special handling (no system role,
                # extra reasoning_content channel).
                if model == "deepseek-reasoner":
                    return await self._chat_deepseek_reasoner(client, model, messages, system_prompt, stream)
                return await self._chat_openai(client, model, messages, system_prompt, stream)

        # Fall back to the static .env configuration.
        config = MODEL_CONFIG.get(task_type, MODEL_CONFIG["reasoning"])
        provider = config["provider"]
        model = config["model"]

        if provider == "anthropic" and self.anthropic_client:
            return await self._chat_anthropic(model, messages, system_prompt, stream)
        elif provider == "openai" and self.openai_client:
            return await self._chat_openai(self.openai_client, model, messages, system_prompt, stream)
        elif provider == "deepseek" and self.deepseek_client:
            return await self._chat_openai(self.deepseek_client, model, messages, system_prompt, stream)
        elif provider == "google":
            if GOOGLE_API_KEY:
                # Gemini via Google's OpenAI-compatibility endpoint.
                google_client = AsyncOpenAI(
                    api_key=GOOGLE_API_KEY,
                    base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
                )
                return await self._chat_openai(google_client, model, messages, system_prompt, stream)

        # Last-resort degradation: try whatever clients exist, in order.
        if self.deepseek_client:
            return await self._chat_openai(self.deepseek_client, "deepseek-chat", messages, system_prompt, stream)
        if self.openai_client:
            return await self._chat_openai(self.openai_client, "gpt-4o-mini", messages, system_prompt, stream)
        if self.anthropic_client:
            return await self._chat_anthropic("claude-sonnet-4-20250514", messages, system_prompt, stream)

        # User-facing message: "No AI model configured; set one up in Model
        # Management." Kept in Chinese — it is runtime output.
        return "未配置任何AI模型,请到「模型管理」页面配置模型和API Key。"

    async def _chat_ark_web_search(
        self, api_key: str, base_url: str, model: str,
        messages: List[dict], system_prompt: str, stream: bool,
        search_count: int = 5,
    ) -> Union[str, AsyncGenerator[str, None]]:
        """Volcengine Ark chat with web search enabled.

        Called via raw httpx rather than the OpenAI SDK because Ark's
        ``web_search`` tool is a non-standard ``tools`` entry the SDK's
        typed models reject.

        Raises:
            Exception: on non-200 responses (non-stream) or when the probe
                request shows the model rejects web_search (stream) — the
                caller catches this and degrades to a plain call.
        """
        url = f"{base_url or ARK_BASE_URL}/chat/completions"
        full_messages = []
        if system_prompt:
            full_messages.append({"role": "system", "content": system_prompt})
        full_messages.extend(messages)

        # Clamp the requested search-result count to the 1-50 range.
        search_count = max(1, min(50, search_count or 5))
        payload = {
            "model": model,
            "messages": full_messages,
            "stream": stream,
            "tools": [{"type": "web_search", "web_search": {"enable": True, "search_result_count": search_count}}],
        }
        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }

        if not stream:
            async with httpx.AsyncClient(timeout=120.0) as client:
                resp = await client.post(url, json=payload, headers=headers)
                if resp.status_code != 200:
                    # Raise so the caller can fall back to a plain call.
                    raise Exception(f"API调用失败 ({resp.status_code}): {resp.text[:200]}")
                data = resp.json()
                return data.get("choices", [{}])[0].get("message", {}).get("content", "")
        else:
            # Streaming: send a cheap non-stream probe (max_tokens=1) first to
            # confirm the model accepts the web_search tool before we commit
            # to a stream we cannot error-check as cleanly.
            async with httpx.AsyncClient(timeout=10.0) as client:
                test_payload = {
                    "model": model,
                    "messages": [{"role": "user", "content": "test"}],
                    "stream": False,
                    "max_tokens": 1,
                    "tools": [{"type": "web_search", "web_search": {"enable": True, "search_result_count": 1}}],
                }
                try:
                    resp = await client.post(url, json=test_payload, headers=headers)
                    if resp.status_code == 400:
                        # Model rejects web_search — raise so the caller degrades.
                        raise Exception("模型不支持联网搜索")
                except httpx.TimeoutException:
                    pass  # Probe timeout is non-fatal; still try streaming.
            return self._stream_ark_web_search(url, payload, headers)

    async def _stream_ark_web_search(
        self, url: str, payload: dict, headers: dict,
    ) -> AsyncGenerator[str, None]:
        """Stream an Ark web-search chat response, yielding text deltas.

        Parses the SSE body by hand: buffers raw chunks, splits on newlines,
        and extracts ``choices[0].delta.content`` from each ``data:`` line.
        """
        async with httpx.AsyncClient(timeout=120.0) as client:
            async with client.stream("POST", url, json=payload, headers=headers) as resp:
                if resp.status_code != 200:
                    # Already streaming to the caller, so surface the error as
                    # a text chunk instead of raising.
                    error_body = await resp.aread()
                    yield f"API调用失败 ({resp.status_code}): {error_body.decode()[:200]}"
                    return
                buffer = ""
                async for chunk in resp.aiter_text():
                    buffer += chunk
                    # Process every complete line currently in the buffer;
                    # a partial line stays buffered for the next chunk.
                    while "\n" in buffer:
                        line, buffer = buffer.split("\n", 1)
                        line = line.strip()
                        if not line or not line.startswith("data: "):
                            continue
                        data_str = line[6:]  # strip the "data: " prefix
                        if data_str == "[DONE]":
                            continue
                        try:
                            data = json.loads(data_str)
                            choices = data.get("choices", [])
                            if choices:
                                delta = choices[0].get("delta", {})
                                content = delta.get("content", "")
                                if content:
                                    yield content
                        except json.JSONDecodeError:
                            # Malformed SSE payload line — skip it.
                            pass

    async def _chat_openai(
        self, client: AsyncOpenAI, model: str, messages: List[dict],
        system_prompt: str, stream: bool,
    ) -> Union[str, AsyncGenerator[str, None]]:
        """Call an OpenAI-compatible chat endpoint.

        Prepends *system_prompt* as a system message when present. Returns the
        response text, or the streaming generator when *stream* is True.
        """
        full_messages = []
        if system_prompt:
            full_messages.append({"role": "system", "content": system_prompt})
        full_messages.extend(messages)

        if stream:
            return self._stream_openai(client, model, full_messages)
        else:
            response = await client.chat.completions.create(
                model=model,
                messages=full_messages,
                temperature=0.7,
                max_tokens=4096,
            )
            return response.choices[0].message.content

    async def _stream_openai(
        self, client: AsyncOpenAI, model: str, messages: List[dict],
    ) -> AsyncGenerator[str, None]:
        """Stream an OpenAI-compatible chat response, yielding content deltas."""
        response = await client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=0.7,
            max_tokens=4096,
            stream=True,
        )
        async for chunk in response:
            if chunk.choices[0].delta.content:
                yield chunk.choices[0].delta.content

    async def _chat_deepseek_reasoner(
        self, client: AsyncOpenAI, model: str, messages: List[dict],
        system_prompt: str, stream: bool,
    ) -> Union[str, AsyncGenerator[str, None]]:
        """DeepSeek Reasoner (thinking-mode) call.

        deepseek-reasoner does not accept temperature/top_p/system parameters;
        its output carries reasoning_content (the thinking trace) alongside
        content (the final answer). The system prompt is therefore folded into
        the first user message.

        NOTE(review): the loop copies the list but not the dicts, so the
        system-prompt merge below mutates the caller's first user-message
        dict in place — confirm callers do not reuse message objects.
        """
        full_messages = []
        for msg in messages:
            full_messages.append(msg)
        if system_prompt and full_messages:
            first_user = None
            for m in full_messages:
                if m["role"] == "user":
                    first_user = m
                    break
            if first_user:
                # Runtime string: "[instruction] ... [user input] ..." markers
                # kept in Chinese — they are sent to the model verbatim.
                first_user["content"] = f"[指令] {system_prompt}\n\n[用户输入] {first_user['content']}"

        if stream:
            return self._stream_deepseek_reasoner(client, model, full_messages)
        else:
            response = await client.chat.completions.create(
                model=model,
                messages=full_messages,
                max_tokens=8192,
            )
            reasoning = getattr(response.choices[0].message, 'reasoning_content', '') or ''
            content = response.choices[0].message.content or ''
            if reasoning:
                # Wrap the thinking trace in <think> tags ahead of the answer.
                return f"<think>\n{reasoning}\n</think>\n\n{content}"
            return content

    async def _stream_deepseek_reasoner(
        self, client: AsyncOpenAI, model: str, messages: List[dict],
    ) -> AsyncGenerator[str, None]:
        """Stream DeepSeek Reasoner output: thinking trace, then the answer.

        The thinking trace is wrapped in a collapsible HTML <details> block;
        the closing tag is emitted when the first answer token arrives.
        NOTE(review): if the stream ends while still inside the trace (no
        content deltas at all), </details> is never emitted — confirm whether
        downstream rendering tolerates that.
        """
        response = await client.chat.completions.create(
            model=model,
            messages=messages,
            max_tokens=8192,
            stream=True,
        )
        in_reasoning = False
        reasoning_started = False
        async for chunk in response:
            delta = chunk.choices[0].delta
            # Thinking-trace deltas.
            reasoning_content = getattr(delta, 'reasoning_content', None)
            if reasoning_content:
                if not reasoning_started:
                    reasoning_started = True
                    in_reasoning = True
                    yield "<details>\n<summary>💭 思考过程</summary>\n\n"
                yield reasoning_content
            # Final-answer deltas.
            if delta.content:
                if in_reasoning:
                    in_reasoning = False
                    yield "\n</details>\n\n"
                yield delta.content

    async def _chat_anthropic(
        self, model: str, messages: List[dict],
        system_prompt: str, stream: bool,
    ) -> Union[str, AsyncGenerator[str, None]]:
        """Anthropic call using the shared self.anthropic_client."""
        return await self._chat_anthropic_with_client(self.anthropic_client, model, messages, system_prompt, stream)

    async def _chat_anthropic_with_client(
        self, client, model: str, messages: List[dict],
        system_prompt: str, stream: bool,
    ) -> Union[str, AsyncGenerator[str, None]]:
        """Anthropic call with an explicit client (system prompt passed via
        the dedicated ``system`` parameter, per the Anthropic API)."""
        if stream:
            return self._stream_anthropic_with_client(client, model, messages, system_prompt)
        else:
            response = await client.messages.create(
                model=model,
                max_tokens=4096,
                system=system_prompt if system_prompt else "You are a helpful assistant.",
                messages=messages,
            )
            return response.content[0].text

    async def _stream_anthropic(
        self, model: str, messages: List[dict], system_prompt: str,
    ) -> AsyncGenerator[str, None]:
        """Anthropic streaming using the shared self.anthropic_client."""
        async for text in self._stream_anthropic_with_client(self.anthropic_client, model, messages, system_prompt):
            yield text

    async def _stream_anthropic_with_client(
        self, client, model: str, messages: List[dict], system_prompt: str,
    ) -> AsyncGenerator[str, None]:
        """Anthropic streaming with an explicit client, yielding text chunks."""
        async with client.messages.stream(
            model=model,
            max_tokens=4096,
            system=system_prompt if system_prompt else "You are a helpful assistant.",
            messages=messages,
        ) as stream:
            async for text in stream.text_stream:
                yield text
|
||||
|
||||
|
||||
# Module-level singleton shared by the rest of the backend.
ai_service = AIService()
|
||||
Reference in New Issue
Block a user