Files
bianchengshequ/backend/routers/web_search.py

275 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""联网搜索路由 - 使用豆包大模型 + 火山方舟 web_search"""
import json
import httpx
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from pydantic import BaseModel
from database import get_db
from models.user import User
from models.conversation import Conversation, Message
from schemas.conversation import ConversationResponse, ConversationDetail, MessageResponse
from routers.auth import get_current_user
from config import ARK_API_KEY, ARK_ENDPOINT, ARK_BASE_URL
# Router holding the web-search endpoints declared in this module (/search and /conversations...).
router = APIRouter()
# System prompt placed as the first ("system") message of every request built in web_search().
# In English it tells the model: it is a web-search assistant; answer accurately from the latest
# search results, use Markdown, note dates for time-sensitive facts, synthesize rather than list
# results, and say so honestly when the results are insufficient.
# NOTE: this text is sent to the model verbatim; do not normalize its punctuation/characters.
WEB_SEARCH_SYSTEM_PROMPT = """你是一个智能联网搜索助手。你可以通过联网搜索获取最新信息来回答用户的问题。
回答要求:
1. 基于搜索到的最新信息给出准确、详细的回答
2. 使用清晰的 Markdown 格式组织内容
3. 如果涉及时效性信息,注明信息的时间
4. 对搜索结果进行整合和总结,而非简单罗列
5. 如果搜索结果不足以回答问题,诚实告知并给出建议"""
class WebSearchRequest(BaseModel):
    """Request body for ``POST /search``."""

    # pydantic v2 (< 2.10) reserves the ``model_`` prefix and warns about the
    # ``model_config_id`` field at class creation; opting out of the protected
    # namespace keeps the public field name unchanged and silences the warning.
    model_config = {"protected_namespaces": ()}

    # Existing conversation to continue; None creates a new "web_search" conversation.
    conversation_id: Optional[int] = None
    # The user's question / search query.
    content: str
    # Optional id of a DB-stored model configuration; when None (or not found)
    # the ARK_* settings from config are used.
    model_config_id: Optional[int] = None
def _clamp_search_count(value, default=5):
    """Coerce a configured web-search result count to an int in the range 1..50.

    ``None`` or any value that ``int()`` cannot convert falls back to ``default``
    (also clamped), so a bad stored value no longer breaks the request.
    """
    try:
        count = int(value)
    except (TypeError, ValueError):
        count = default
    return max(1, min(50, count))


async def _iter_sse_content(resp):
    """Yield the non-empty ``choices[0].delta.content`` strings of a streamed chat completion.

    ``resp`` is an open streaming ``httpx`` response whose body is server-sent
    events. Non-``data:`` lines, the ``[DONE]`` sentinel, undecodable JSON and
    chunks carrying no text are skipped instead of aborting the stream.
    """
    async for raw_line in resp.aiter_lines():
        line = raw_line.strip()
        if not line.startswith("data:"):
            continue
        # SSE permits both "data: x" and "data:x".
        data_str = line[len("data:"):].strip()
        if not data_str or data_str == "[DONE]":
            continue
        try:
            data = json.loads(data_str)
        except json.JSONDecodeError:
            continue
        if not isinstance(data, dict):
            continue
        choices = data.get("choices") or []
        first = choices[0] if isinstance(choices, list) and choices else None
        if not isinstance(first, dict):
            continue
        # "delta" may be absent or null on some chunks (e.g. the final one).
        delta = first.get("delta") or {}
        content = delta.get("content") if isinstance(delta, dict) else None
        if isinstance(content, str) and content:
            yield content


@router.post("/search")
async def web_search(
    request: WebSearchRequest,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Answer ``request.content`` via the Ark chat-completions API with web search, streamed as SSE.

    Saves the user message (creating a ``web_search`` conversation when no
    ``conversation_id`` is given), streams ``{"content": ..., "done": false}``
    events, then stores the assistant reply and sends a final event with
    ``done: true`` and the ``conversation_id``.

    Raises:
        HTTPException(500): no Ark API key is configured.
        HTTPException(404): ``conversation_id`` is not one of the user's conversations.
    """
    # Imported lazily, presumably to avoid a circular import — TODO confirm.
    from services.ai_service import _get_db_model_config

    # Model config: the requested DB config if it exists, otherwise the ARK_* settings.
    ark_api_key = ARK_API_KEY
    ark_endpoint = ARK_ENDPOINT
    ark_base_url = ARK_BASE_URL
    search_count = 5
    if request.model_config_id:
        cfg = _get_db_model_config("", request.model_config_id)
        if cfg:
            ark_api_key = cfg["api_key"]
            ark_endpoint = cfg["model"]
            ark_base_url = cfg["base_url"] or ARK_BASE_URL
            search_count = cfg.get("web_search_count", 5)
    if not ark_api_key:
        raise HTTPException(status_code=500, detail="未配置火山方舟 API Key")
    # Validate now so a null/garbage stored count cannot fail inside the stream.
    search_count = _clamp_search_count(search_count)

    # Load the user's conversation or create a new one.
    if request.conversation_id:
        conv = db.query(Conversation).filter(
            Conversation.id == request.conversation_id,
            Conversation.user_id == current_user.id,
        ).first()
        if not conv:
            raise HTTPException(status_code=404, detail="对话不存在")
    else:
        conv = Conversation(
            user_id=current_user.id,
            title=request.content[:50] if request.content else "新搜索",
            type="web_search",
        )
        db.add(conv)
        db.commit()
        db.refresh(conv)
    # Capture the id once so the streaming generator never lazily reloads `conv`
    # after later commits have expired it.
    conv_id = conv.id

    # Persist the user's message before calling the model.
    db.add(Message(conversation_id=conv_id, role="user", content=request.content))
    db.commit()

    # Full history (including the message just saved), oldest first.
    history_msgs = (
        db.query(Message)
        .filter(Message.conversation_id == conv_id)
        .order_by(Message.created_at.asc())
        .all()
    )
    messages = [{"role": "system", "content": WEB_SEARCH_SYSTEM_PROMPT}]
    messages.extend({"role": msg.role, "content": msg.content} for msg in history_msgs)

    url = f"{ark_base_url}/chat/completions"
    headers = {
        "Authorization": f"Bearer {ark_api_key}",
        "Content-Type": "application/json",
    }
    payload_with_search = {
        "model": ark_endpoint,
        "messages": messages,
        "stream": True,
        "tools": [
            {
                "type": "web_search",
                "web_search": {"enable": True, "search_result_count": search_count},
            }
        ],
    }
    payload_without_search = {
        "model": ark_endpoint,
        "messages": messages,
        "stream": True,
    }

    async def generate():
        full_response = ""
        # First try with the web_search tool; an HTTP 400 is treated as "model does
        # not support tools" and the request is retried once without it.
        attempts = (payload_with_search, payload_without_search)
        try:
            async with httpx.AsyncClient(timeout=120.0) as client:
                for index, payload in enumerate(attempts):
                    can_fall_back = index < len(attempts) - 1
                    async with client.stream("POST", url, json=payload, headers=headers) as resp:
                        if resp.status_code == 400 and can_fall_back:
                            await resp.aread()
                            continue
                        if resp.status_code != 200:
                            error_body = await resp.aread()
                            # Lenient decode: a non-UTF-8 error body must not raise here.
                            error_msg = (
                                f"API 调用失败 ({resp.status_code}): "
                                f"{error_body.decode('utf-8', errors='replace')}"
                            )
                            yield f"data: {json.dumps({'content': error_msg, 'done': False})}\n\n"
                            full_response = error_msg
                        else:
                            async for content in _iter_sse_content(resp):
                                full_response += content
                                yield f"data: {json.dumps({'content': content, 'done': False})}\n\n"
                    break
        except Exception as e:
            error_msg = f"联网搜索出错: {str(e)}"
            # Keep any partial answer already streamed; only store the error if nothing arrived.
            if not full_response:
                full_response = error_msg
            yield f"data: {json.dumps({'content': error_msg, 'done': False})}\n\n"

        # Persist the assistant reply (or the error text shown to the user).
        if full_response:
            db.add(Message(conversation_id=conv_id, role="assistant", content=full_response))
            db.commit()
        yield f"data: {json.dumps({'content': '', 'done': True, 'conversation_id': conv_id})}\n\n"

    return StreamingResponse(generate(), media_type="text/event-stream")
@router.get("/conversations", response_model=List[ConversationResponse])
def get_conversations(
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """List the current user's web-search conversations, most recently updated first."""
    owned_searches = db.query(Conversation).filter(
        Conversation.user_id == current_user.id,
        Conversation.type == "web_search",
    )
    rows = owned_searches.order_by(Conversation.updated_at.desc()).all()
    return list(map(ConversationResponse.model_validate, rows))
@router.get("/conversations/{conversation_id}", response_model=ConversationDetail)
def get_conversation_detail(
    conversation_id: int,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Return one of the current user's conversations with its messages, oldest first.

    Raises:
        HTTPException(404): the conversation does not exist or belongs to another user.
    """
    owned = (
        db.query(Conversation)
        .filter(
            Conversation.id == conversation_id,
            Conversation.user_id == current_user.id,
        )
        .first()
    )
    if not owned:
        raise HTTPException(status_code=404, detail="对话不存在")

    ordered_messages = (
        db.query(Message)
        .filter(Message.conversation_id == conversation_id)
        .order_by(Message.created_at.asc())
        .all()
    )
    detail = ConversationDetail.model_validate(owned)
    # Attach the explicitly ordered message list to the response model.
    detail.messages = [MessageResponse.model_validate(item) for item in ordered_messages]
    return detail
@router.delete("/conversations/{conversation_id}")
def delete_conversation(
    conversation_id: int,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Delete one of the current user's conversations together with all its messages.

    Raises:
        HTTPException(404): the conversation does not exist or belongs to another user.
    """
    target = (
        db.query(Conversation)
        .filter(
            Conversation.id == conversation_id,
            Conversation.user_id == current_user.id,
        )
        .first()
    )
    if not target:
        raise HTTPException(status_code=404, detail="对话不存在")

    # Bulk-delete the child messages first, then the conversation row, in one commit.
    child_messages = db.query(Message).filter(Message.conversation_id == conversation_id)
    child_messages.delete()
    db.delete(target)
    db.commit()
    return {"message": "删除成功"}