275 lines
11 KiB
Python
275 lines
11 KiB
Python
"""联网搜索路由 - 使用豆包大模型 + 火山方舟 web_search"""
|
||
import json
|
||
import httpx
|
||
from typing import List, Optional
|
||
from fastapi import APIRouter, Depends, HTTPException
|
||
from fastapi.responses import StreamingResponse
|
||
from sqlalchemy.orm import Session
|
||
from pydantic import BaseModel
|
||
|
||
from database import get_db
|
||
from models.user import User
|
||
from models.conversation import Conversation, Message
|
||
from schemas.conversation import ConversationResponse, ConversationDetail, MessageResponse
|
||
from routers.auth import get_current_user
|
||
from config import ARK_API_KEY, ARK_ENDPOINT, ARK_BASE_URL
|
||
|
||
# Router holding the web-search endpoints defined in this module; presumably
# included into the application elsewhere (prefix not visible here).
router = APIRouter()

# System prompt placed as the first message of every web_search request (see the
# `messages` list built in web_search). It instructs the model to answer from the
# latest search results, format with Markdown, note the time of time-sensitive
# facts, synthesize rather than merely list results, and say honestly when the
# results are insufficient.
WEB_SEARCH_SYSTEM_PROMPT = """你是一个智能联网搜索助手。你可以通过联网搜索获取最新信息来回答用户的问题。

回答要求:
1. 基于搜索到的最新信息给出准确、详细的回答
2. 使用清晰的 Markdown 格式组织内容
3. 如果涉及时效性信息,注明信息的时间
4. 对搜索结果进行整合和总结,而非简单罗列
5. 如果搜索结果不足以回答问题,诚实告知并给出建议"""
|
||
|
||
|
||
class WebSearchRequest(BaseModel):
    """Request body for the ``POST /search`` web-search endpoint."""

    # Pydantic v2 (releases before 2.10) reserves the ``model_`` prefix and warns
    # at import time that ``model_config_id`` "has conflict with protected
    # namespace". Opting out keeps the public field name without the warning.
    model_config = {"protected_namespaces": ()}

    # Existing conversation to continue; None starts a new "web_search" conversation.
    conversation_id: Optional[int] = None
    # The user's question; stored as a user message and sent to the model.
    content: str
    # Optional stored model-configuration id. When it resolves, its api_key /
    # model / base_url override the ARK settings from config.
    model_config_id: Optional[int] = None
|
||
|
||
|
||
def _sse_event(payload: dict) -> str:
    """Serialize *payload* as one Server-Sent Events ``data:`` frame."""
    return f"data: {json.dumps(payload)}\n\n"


def _format_api_error(status_code: int, body: bytes) -> str:
    """Build the user-facing message for a non-200 upstream response.

    The body is decoded with ``errors="replace"`` so a non-UTF-8 error page
    cannot raise ``UnicodeDecodeError`` and hide the real HTTP status.
    """
    return f"API 调用失败 ({status_code}): {body.decode(errors='replace')}"


def _extract_delta_content(line: str) -> str:
    """Return the ``choices[0].delta.content`` text carried by one SSE line.

    Returns "" for blank lines, non-``data:`` fields, the ``[DONE]`` sentinel,
    malformed JSON, and chunks without string content (e.g. role-only deltas
    or ``"content": null``).
    """
    line = line.strip()
    if not line.startswith("data:"):
        return ""
    # The SSE spec allows an optional space after "data:".
    data_str = line[len("data:"):].strip()
    if data_str == "[DONE]":
        return ""
    try:
        data = json.loads(data_str)
        content = data["choices"][0]["delta"]["content"]
    except (json.JSONDecodeError, KeyError, IndexError, TypeError):
        return ""
    return content if isinstance(content, str) else ""


async def _iter_stream_content(resp: "httpx.Response"):
    """Yield each non-empty content delta from a streaming chat-completions response."""
    async for line in resp.aiter_lines():
        content = _extract_delta_content(line)
        if content:
            yield content


@router.post("/search")
async def web_search(
    request: WebSearchRequest,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Answer ``request.content`` with web search, streamed as Server-Sent Events.

    Credentials come from the stored model config named by
    ``request.model_config_id`` when it resolves, otherwise from the ARK
    settings in ``config``. The user message is persisted before the model is
    called. The streamed answer (or an error text) is persisted as an assistant
    message once streaming ends.

    Each frame is ``{"content": <text>, "done": false}``. The final frame is
    ``{"content": "", "done": true, "conversation_id": <id>}``.

    Raises:
        HTTPException: 500 if no API key is configured; 404 if
            ``conversation_id`` is not one of the current user's conversations.
    """
    # Function-scope import kept as in the original (presumably avoids an
    # import cycle with services.ai_service — confirm).
    from services.ai_service import _get_db_model_config

    ark_api_key = ARK_API_KEY
    ark_endpoint = ARK_ENDPOINT
    ark_base_url = ARK_BASE_URL
    search_count = 5

    if request.model_config_id:
        cfg = _get_db_model_config("", request.model_config_id)
        if cfg:
            ark_api_key = cfg["api_key"]
            ark_endpoint = cfg["model"]
            ark_base_url = cfg["base_url"] or ARK_BASE_URL
            # A stored NULL must fall back to the default rather than break
            # the min()/max() clamp when the payload is built.
            configured_count = cfg.get("web_search_count")
            if configured_count is not None:
                search_count = configured_count

    if not ark_api_key:
        raise HTTPException(status_code=500, detail="未配置火山方舟 API Key")

    # Continue an existing conversation owned by this user, or create a new one.
    if request.conversation_id:
        conv = db.query(Conversation).filter(
            Conversation.id == request.conversation_id,
            Conversation.user_id == current_user.id,
        ).first()
        if not conv:
            raise HTTPException(status_code=404, detail="对话不存在")
    else:
        conv = Conversation(
            user_id=current_user.id,
            title=request.content[:50] if request.content else "新搜索",
            type="web_search",
        )
        db.add(conv)
        db.commit()
        db.refresh(conv)

    db.add(Message(conversation_id=conv.id, role="user", content=request.content))
    db.commit()

    # Capture the id once. The generator below runs after this function has
    # returned, and must not lazily refresh an ORM instance at that point.
    conv_id = conv.id

    # Full history, including the message just saved, after the system prompt.
    history_msgs = (
        db.query(Message)
        .filter(Message.conversation_id == conv_id)
        .order_by(Message.created_at.asc())
        .all()
    )
    messages = [{"role": "system", "content": WEB_SEARCH_SYSTEM_PROMPT}]
    messages.extend({"role": msg.role, "content": msg.content} for msg in history_msgs)

    url = f"{ark_base_url}/chat/completions"
    headers = {
        "Authorization": f"Bearer {ark_api_key}",
        "Content-Type": "application/json",
    }

    # NOTE(review): the synchronous Session is used inside an async endpoint and
    # async generator, so DB calls block the event loop — confirm acceptable.
    async def generate():
        full_response = ""
        try:
            payload_without_search = {
                "model": ark_endpoint,
                "messages": messages,
                "stream": True,
            }
            payload_with_search = {
                **payload_without_search,
                "tools": [
                    {
                        "type": "web_search",
                        "web_search": {
                            "enable": True,
                            "search_result_count": max(1, min(50, search_count)),
                        },
                    }
                ],
            }

            async with httpx.AsyncClient(timeout=120.0) as client:
                # Try with the web_search tool first. A 400 on that attempt is
                # taken to mean the model does not support the tool, so retry
                # once without it. Any other outcome ends the loop.
                for payload in (payload_with_search, payload_without_search):
                    async with client.stream("POST", url, json=payload, headers=headers) as resp:
                        if resp.status_code == 400 and payload is payload_with_search:
                            await resp.aread()
                            continue
                        if resp.status_code != 200:
                            error_body = await resp.aread()
                            full_response = _format_api_error(resp.status_code, error_body)
                            yield _sse_event({"content": full_response, "done": False})
                        else:
                            async for content in _iter_stream_content(resp):
                                full_response += content
                                yield _sse_event({"content": content, "done": False})
                    break
        except Exception as e:
            # Streaming boundary: report the failure to the client in-band and
            # keep any partial answer already received.
            error_msg = f"联网搜索出错: {str(e)}"
            if not full_response:
                full_response = error_msg
            yield _sse_event({"content": error_msg, "done": False})

        if full_response:
            db.add(Message(conversation_id=conv_id, role="assistant", content=full_response))
            db.commit()

        yield _sse_event({"content": "", "done": True, "conversation_id": conv_id})

    return StreamingResponse(generate(), media_type="text/event-stream")
|
||
|
||
|
||
@router.get("/conversations", response_model=List[ConversationResponse])
def get_conversations(
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """List the current user's "web_search" conversations, newest update first."""
    owned_searches = db.query(Conversation).filter(
        Conversation.user_id == current_user.id,
        Conversation.type == "web_search",
    )
    newest_first = owned_searches.order_by(Conversation.updated_at.desc())
    return list(map(ConversationResponse.model_validate, newest_first.all()))
|
||
|
||
|
||
@router.get("/conversations/{conversation_id}", response_model=ConversationDetail)
def get_conversation_detail(
    conversation_id: int,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Return one of the current user's conversations with its messages, oldest first.

    NOTE(review): unlike the list endpoint, this lookup does not filter on
    ``type == "web_search"`` — confirm that is intended.

    Raises:
        HTTPException: 404 if the conversation does not exist or belongs to
            another user.
    """
    owned_conv = (
        db.query(Conversation)
        .filter(Conversation.id == conversation_id, Conversation.user_id == current_user.id)
        .first()
    )
    if not owned_conv:
        raise HTTPException(status_code=404, detail="对话不存在")

    chronological = db.query(Message).filter(Message.conversation_id == conversation_id)
    message_rows = chronological.order_by(Message.created_at.asc()).all()

    detail = ConversationDetail.model_validate(owned_conv)
    detail.messages = list(map(MessageResponse.model_validate, message_rows))
    return detail
|
||
|
||
|
||
@router.delete("/conversations/{conversation_id}")
def delete_conversation(
    conversation_id: int,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Delete one of the current user's conversations together with its messages.

    Raises:
        HTTPException: 404 if the conversation does not exist or belongs to
            another user.
    """
    owned_conv = (
        db.query(Conversation)
        .filter(Conversation.id == conversation_id, Conversation.user_id == current_user.id)
        .first()
    )
    if not owned_conv:
        raise HTTPException(status_code=404, detail="对话不存在")

    # Child messages are bulk-deleted before the parent row. Presumably this
    # avoids orphans or FK violations — confirm against the model's cascade rules.
    child_messages = db.query(Message).filter(Message.conversation_id == conversation_id)
    child_messages.delete()
    db.delete(owned_conv)
    db.commit()

    return {"message": "删除成功"}
|