初始提交:极码 GeekCode 全栈项目(FastAPI + Vue3)
This commit is contained in:
274
backend/routers/web_search.py
Normal file
274
backend/routers/web_search.py
Normal file
@@ -0,0 +1,274 @@
|
||||
"""联网搜索路由 - 使用豆包大模型 + 火山方舟 web_search"""
|
||||
import json
|
||||
import httpx
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
from pydantic import BaseModel
|
||||
|
||||
from database import get_db
|
||||
from models.user import User
|
||||
from models.conversation import Conversation, Message
|
||||
from schemas.conversation import ConversationResponse, ConversationDetail, MessageResponse
|
||||
from routers.auth import get_current_user
|
||||
from config import ARK_API_KEY, ARK_ENDPOINT, ARK_BASE_URL
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
WEB_SEARCH_SYSTEM_PROMPT = """你是一个智能联网搜索助手。你可以通过联网搜索获取最新信息来回答用户的问题。
|
||||
|
||||
回答要求:
|
||||
1. 基于搜索到的最新信息给出准确、详细的回答
|
||||
2. 使用清晰的 Markdown 格式组织内容
|
||||
3. 如果涉及时效性信息,注明信息的时间
|
||||
4. 对搜索结果进行整合和总结,而非简单罗列
|
||||
5. 如果搜索结果不足以回答问题,诚实告知并给出建议"""
|
||||
|
||||
|
||||
class WebSearchRequest(BaseModel):
|
||||
conversation_id: Optional[int] = None
|
||||
content: str
|
||||
model_config_id: Optional[int] = None
|
||||
|
||||
|
||||
@router.post("/search")
|
||||
async def web_search(
|
||||
request: WebSearchRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""联网搜索 - 流式输出"""
|
||||
# 获取模型配置:优先用指定的,否则用 .env 的 ARK 配置
|
||||
from services.ai_service import _get_db_model_config
|
||||
ark_api_key = ARK_API_KEY
|
||||
ark_endpoint = ARK_ENDPOINT
|
||||
ark_base_url = ARK_BASE_URL
|
||||
search_count = 5
|
||||
|
||||
if request.model_config_id:
|
||||
cfg = _get_db_model_config("", request.model_config_id)
|
||||
if cfg:
|
||||
ark_api_key = cfg["api_key"]
|
||||
ark_endpoint = cfg["model"]
|
||||
ark_base_url = cfg["base_url"] or ARK_BASE_URL
|
||||
search_count = cfg.get("web_search_count", 5)
|
||||
|
||||
if not ark_api_key:
|
||||
raise HTTPException(status_code=500, detail="未配置火山方舟 API Key")
|
||||
|
||||
# 创建或获取对话
|
||||
if request.conversation_id:
|
||||
conv = db.query(Conversation).filter(
|
||||
Conversation.id == request.conversation_id,
|
||||
Conversation.user_id == current_user.id,
|
||||
).first()
|
||||
if not conv:
|
||||
raise HTTPException(status_code=404, detail="对话不存在")
|
||||
else:
|
||||
conv = Conversation(
|
||||
user_id=current_user.id,
|
||||
title=request.content[:50] if request.content else "新搜索",
|
||||
type="web_search",
|
||||
)
|
||||
db.add(conv)
|
||||
db.commit()
|
||||
db.refresh(conv)
|
||||
|
||||
# 保存用户消息
|
||||
user_msg = Message(
|
||||
conversation_id=conv.id,
|
||||
role="user",
|
||||
content=request.content,
|
||||
)
|
||||
db.add(user_msg)
|
||||
db.commit()
|
||||
|
||||
# 构建历史消息
|
||||
history_msgs = (
|
||||
db.query(Message)
|
||||
.filter(Message.conversation_id == conv.id)
|
||||
.order_by(Message.created_at.asc())
|
||||
.all()
|
||||
)
|
||||
|
||||
messages = [{"role": "system", "content": WEB_SEARCH_SYSTEM_PROMPT}]
|
||||
for msg in history_msgs:
|
||||
messages.append({"role": msg.role, "content": msg.content})
|
||||
|
||||
# 流式调用火山方舟 API
|
||||
url = f"{ark_base_url}/chat/completions"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {ark_api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
|
||||
async def generate():
|
||||
full_response = ""
|
||||
try:
|
||||
payload_with_search = {
|
||||
"model": ark_endpoint,
|
||||
"messages": messages,
|
||||
"stream": True,
|
||||
"tools": [
|
||||
{
|
||||
"type": "web_search",
|
||||
"web_search": {"enable": True, "search_result_count": max(1, min(50, search_count))},
|
||||
}
|
||||
],
|
||||
}
|
||||
payload_without_search = {
|
||||
"model": ark_endpoint,
|
||||
"messages": messages,
|
||||
"stream": True,
|
||||
}
|
||||
|
||||
# 先尝试带联网搜索调用
|
||||
use_fallback = False
|
||||
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||
async with client.stream("POST", url, json=payload_with_search, headers=headers) as resp:
|
||||
if resp.status_code == 400:
|
||||
# 模型不支持 web_search tools,降级到普通调用
|
||||
await resp.aread()
|
||||
use_fallback = True
|
||||
elif resp.status_code != 200:
|
||||
error_body = await resp.aread()
|
||||
error_msg = f"API 调用失败 ({resp.status_code}): {error_body.decode()}"
|
||||
yield f"data: {json.dumps({'content': error_msg, 'done': False})}\n\n"
|
||||
full_response = error_msg
|
||||
else:
|
||||
buffer = ""
|
||||
async for chunk in resp.aiter_text():
|
||||
buffer += chunk
|
||||
while "\n" in buffer:
|
||||
line, buffer = buffer.split("\n", 1)
|
||||
line = line.strip()
|
||||
if not line or not line.startswith("data: "):
|
||||
continue
|
||||
data_str = line[6:]
|
||||
if data_str == "[DONE]":
|
||||
continue
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
choices = data.get("choices", [])
|
||||
if choices:
|
||||
delta = choices[0].get("delta", {})
|
||||
content = delta.get("content", "")
|
||||
if content:
|
||||
full_response += content
|
||||
yield f"data: {json.dumps({'content': content, 'done': False})}\n\n"
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 降级:不带 web_search tools 重试
|
||||
if use_fallback:
|
||||
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||
async with client.stream("POST", url, json=payload_without_search, headers=headers) as resp:
|
||||
if resp.status_code != 200:
|
||||
error_body = await resp.aread()
|
||||
error_msg = f"API 调用失败 ({resp.status_code}): {error_body.decode()}"
|
||||
yield f"data: {json.dumps({'content': error_msg, 'done': False})}\n\n"
|
||||
full_response = error_msg
|
||||
else:
|
||||
buffer = ""
|
||||
async for chunk in resp.aiter_text():
|
||||
buffer += chunk
|
||||
while "\n" in buffer:
|
||||
line, buffer = buffer.split("\n", 1)
|
||||
line = line.strip()
|
||||
if not line or not line.startswith("data: "):
|
||||
continue
|
||||
data_str = line[6:]
|
||||
if data_str == "[DONE]":
|
||||
continue
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
choices = data.get("choices", [])
|
||||
if choices:
|
||||
delta = choices[0].get("delta", {})
|
||||
content = delta.get("content", "")
|
||||
if content:
|
||||
full_response += content
|
||||
yield f"data: {json.dumps({'content': content, 'done': False})}\n\n"
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
except Exception as e:
|
||||
error_msg = f"联网搜索出错: {str(e)}"
|
||||
if not full_response:
|
||||
full_response = error_msg
|
||||
yield f"data: {json.dumps({'content': error_msg, 'done': False})}\n\n"
|
||||
|
||||
# 保存AI回复
|
||||
if full_response:
|
||||
ai_msg = Message(
|
||||
conversation_id=conv.id,
|
||||
role="assistant",
|
||||
content=full_response,
|
||||
)
|
||||
db.add(ai_msg)
|
||||
db.commit()
|
||||
|
||||
yield f"data: {json.dumps({'content': '', 'done': True, 'conversation_id': conv.id})}\n\n"
|
||||
|
||||
return StreamingResponse(generate(), media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.get("/conversations", response_model=List[ConversationResponse])
|
||||
def get_conversations(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取联网搜索对话列表"""
|
||||
conversations = (
|
||||
db.query(Conversation)
|
||||
.filter(Conversation.user_id == current_user.id, Conversation.type == "web_search")
|
||||
.order_by(Conversation.updated_at.desc())
|
||||
.all()
|
||||
)
|
||||
return [ConversationResponse.model_validate(c) for c in conversations]
|
||||
|
||||
|
||||
@router.get("/conversations/{conversation_id}", response_model=ConversationDetail)
|
||||
def get_conversation_detail(
|
||||
conversation_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取对话详情"""
|
||||
conv = db.query(Conversation).filter(
|
||||
Conversation.id == conversation_id,
|
||||
Conversation.user_id == current_user.id,
|
||||
).first()
|
||||
if not conv:
|
||||
raise HTTPException(status_code=404, detail="对话不存在")
|
||||
|
||||
msgs = (
|
||||
db.query(Message)
|
||||
.filter(Message.conversation_id == conversation_id)
|
||||
.order_by(Message.created_at.asc())
|
||||
.all()
|
||||
)
|
||||
result = ConversationDetail.model_validate(conv)
|
||||
result.messages = [MessageResponse.model_validate(m) for m in msgs]
|
||||
return result
|
||||
|
||||
|
||||
@router.delete("/conversations/{conversation_id}")
|
||||
def delete_conversation(
|
||||
conversation_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""删除对话"""
|
||||
conv = db.query(Conversation).filter(
|
||||
Conversation.id == conversation_id,
|
||||
Conversation.user_id == current_user.id,
|
||||
).first()
|
||||
if not conv:
|
||||
raise HTTPException(status_code=404, detail="对话不存在")
|
||||
|
||||
db.query(Message).filter(Message.conversation_id == conversation_id).delete()
|
||||
db.delete(conv)
|
||||
db.commit()
|
||||
return {"message": "删除成功"}
|
||||
Reference in New Issue
Block a user