"""联网搜索路由 - 使用豆包大模型 + 火山方舟 web_search""" import json import httpx from typing import List, Optional from fastapi import APIRouter, Depends, HTTPException from fastapi.responses import StreamingResponse from sqlalchemy.orm import Session from pydantic import BaseModel from database import get_db from models.user import User from models.conversation import Conversation, Message from schemas.conversation import ConversationResponse, ConversationDetail, MessageResponse from routers.auth import get_current_user from config import ARK_API_KEY, ARK_ENDPOINT, ARK_BASE_URL router = APIRouter() WEB_SEARCH_SYSTEM_PROMPT = """你是一个智能联网搜索助手。你可以通过联网搜索获取最新信息来回答用户的问题。 回答要求: 1. 基于搜索到的最新信息给出准确、详细的回答 2. 使用清晰的 Markdown 格式组织内容 3. 如果涉及时效性信息,注明信息的时间 4. 对搜索结果进行整合和总结,而非简单罗列 5. 如果搜索结果不足以回答问题,诚实告知并给出建议""" class WebSearchRequest(BaseModel): conversation_id: Optional[int] = None content: str model_config_id: Optional[int] = None @router.post("/search") async def web_search( request: WebSearchRequest, current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ): """联网搜索 - 流式输出""" # 获取模型配置:优先用指定的,否则用 .env 的 ARK 配置 from services.ai_service import _get_db_model_config ark_api_key = ARK_API_KEY ark_endpoint = ARK_ENDPOINT ark_base_url = ARK_BASE_URL search_count = 5 if request.model_config_id: cfg = _get_db_model_config("", request.model_config_id) if cfg: ark_api_key = cfg["api_key"] ark_endpoint = cfg["model"] ark_base_url = cfg["base_url"] or ARK_BASE_URL search_count = cfg.get("web_search_count", 5) if not ark_api_key: raise HTTPException(status_code=500, detail="未配置火山方舟 API Key") # 创建或获取对话 if request.conversation_id: conv = db.query(Conversation).filter( Conversation.id == request.conversation_id, Conversation.user_id == current_user.id, ).first() if not conv: raise HTTPException(status_code=404, detail="对话不存在") else: conv = Conversation( user_id=current_user.id, title=request.content[:50] if request.content else "新搜索", type="web_search", ) db.add(conv) db.commit() db.refresh(conv) # 保存用户消息 user_msg = Message( conversation_id=conv.id, role="user", content=request.content, ) db.add(user_msg) db.commit() # 构建历史消息 history_msgs = ( db.query(Message) .filter(Message.conversation_id == conv.id) .order_by(Message.created_at.asc()) .all() ) messages = [{"role": "system", "content": WEB_SEARCH_SYSTEM_PROMPT}] for msg in history_msgs: messages.append({"role": msg.role, "content": msg.content}) # 流式调用火山方舟 API url = f"{ark_base_url}/chat/completions" headers = { "Authorization": f"Bearer {ark_api_key}", "Content-Type": "application/json", } async def generate(): full_response = "" try: payload_with_search = { "model": ark_endpoint, "messages": messages, "stream": True, "tools": [ { "type": "web_search", "web_search": {"enable": True, "search_result_count": max(1, min(50, search_count))}, } ], } payload_without_search = { "model": ark_endpoint, "messages": messages, "stream": True, } # 先尝试带联网搜索调用 use_fallback = False async with httpx.AsyncClient(timeout=120.0) as client: async with client.stream("POST", url, json=payload_with_search, headers=headers) as resp: if resp.status_code == 400: # 模型不支持 web_search tools,降级到普通调用 await resp.aread() use_fallback = True elif resp.status_code != 200: error_body = await resp.aread() error_msg = f"API 调用失败 ({resp.status_code}): {error_body.decode()}" yield f"data: {json.dumps({'content': error_msg, 'done': False})}\n\n" full_response = error_msg else: buffer = "" async for chunk in resp.aiter_text(): buffer += chunk while "\n" in buffer: line, buffer = buffer.split("\n", 1) line = line.strip() if not line or not line.startswith("data: "): continue data_str = line[6:] if data_str == "[DONE]": continue try: data = json.loads(data_str) choices = data.get("choices", []) if choices: delta = choices[0].get("delta", {}) content = delta.get("content", "") if content: full_response += content yield f"data: {json.dumps({'content': content, 'done': False})}\n\n" except json.JSONDecodeError: pass # 降级:不带 web_search tools 重试 if use_fallback: async with httpx.AsyncClient(timeout=120.0) as client: async with client.stream("POST", url, json=payload_without_search, headers=headers) as resp: if resp.status_code != 200: error_body = await resp.aread() error_msg = f"API 调用失败 ({resp.status_code}): {error_body.decode()}" yield f"data: {json.dumps({'content': error_msg, 'done': False})}\n\n" full_response = error_msg else: buffer = "" async for chunk in resp.aiter_text(): buffer += chunk while "\n" in buffer: line, buffer = buffer.split("\n", 1) line = line.strip() if not line or not line.startswith("data: "): continue data_str = line[6:] if data_str == "[DONE]": continue try: data = json.loads(data_str) choices = data.get("choices", []) if choices: delta = choices[0].get("delta", {}) content = delta.get("content", "") if content: full_response += content yield f"data: {json.dumps({'content': content, 'done': False})}\n\n" except json.JSONDecodeError: pass except Exception as e: error_msg = f"联网搜索出错: {str(e)}" if not full_response: full_response = error_msg yield f"data: {json.dumps({'content': error_msg, 'done': False})}\n\n" # 保存AI回复 if full_response: ai_msg = Message( conversation_id=conv.id, role="assistant", content=full_response, ) db.add(ai_msg) db.commit() yield f"data: {json.dumps({'content': '', 'done': True, 'conversation_id': conv.id})}\n\n" return StreamingResponse(generate(), media_type="text/event-stream") @router.get("/conversations", response_model=List[ConversationResponse]) def get_conversations( current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ): """获取联网搜索对话列表""" conversations = ( db.query(Conversation) .filter(Conversation.user_id == current_user.id, Conversation.type == "web_search") .order_by(Conversation.updated_at.desc()) .all() ) return [ConversationResponse.model_validate(c) for c in conversations] @router.get("/conversations/{conversation_id}", response_model=ConversationDetail) def get_conversation_detail( conversation_id: int, current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ): """获取对话详情""" conv = db.query(Conversation).filter( Conversation.id == conversation_id, Conversation.user_id == current_user.id, ).first() if not conv: raise HTTPException(status_code=404, detail="对话不存在") msgs = ( db.query(Message) .filter(Message.conversation_id == conversation_id) .order_by(Message.created_at.asc()) .all() ) result = ConversationDetail.model_validate(conv) result.messages = [MessageResponse.model_validate(m) for m in msgs] return result @router.delete("/conversations/{conversation_id}") def delete_conversation( conversation_id: int, current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ): """删除对话""" conv = db.query(Conversation).filter( Conversation.id == conversation_id, Conversation.user_id == current_user.id, ).first() if not conv: raise HTTPException(status_code=404, detail="对话不存在") db.query(Message).filter(Message.conversation_id == conversation_id).delete() db.delete(conv) db.commit() return {"message": "删除成功"}