Files
bianchengshequ/backend/routers/knowledge_base.py

634 lines
22 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""团队知识库路由"""
import json
import hashlib
from typing import Optional, List
from datetime import datetime, timedelta
from fastapi import APIRouter, Depends, HTTPException, Header, Query
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from sqlalchemy import func as sa_func, or_
from pydantic import BaseModel
from database import get_db
from config import SECRET_KEY
from models.user import User
from models.post import Post
from models.system_config import SystemConfig
from models.knowledge_base import KbCategory, KbItem, KbAccessLog
from models.attachment import Attachment
from routers.auth import get_current_user, get_admin_user
from services.ai_service import ai_service
# Router object that the endpoints below register themselves on via @router.<method>(...).
router = APIRouter()
# ========== 密码机制(复用 API Hub 方案) ==========
def _get_kb_password(db: Session) -> str:
    """Return the stored knowledge-base password value.

    Reads the ``SystemConfig`` row whose key is ``"kb_password"``.

    Returns:
        The row's ``value``, or ``""`` when no such row exists.
    """
    row = (
        db.query(SystemConfig)
        .filter(SystemConfig.key == "kb_password")
        .first()
    )
    if not row:
        return ""
    return row.value
def _password_version(db: Session) -> str:
    """Return a short version tag for the currently stored password.

    The tag is the first 8 characters of the stored password value, or
    ``"none"`` when no password is stored. It is embedded in access tokens so
    that tokens issued before a password change stop validating.
    """
    stored = _get_kb_password(db)
    if not stored:
        return "none"
    return stored[:8]
def _hash_password(pwd: str) -> str:
    """Return the hex-encoded SHA-256 digest of ``pwd`` (UTF-8 encoded)."""
    digest = hashlib.sha256()
    digest.update(pwd.encode())
    return digest.hexdigest()
def _create_kb_token(user_id: int, pwd_ver: str = "none") -> str:
    """Issue a signed knowledge-base access token for ``user_id``.

    The token is an HS256 JWT signed with ``SECRET_KEY``. Its claims are the
    user id as a string (``sub``), a ``kb`` marker, the password version
    (``pv``) and an expiry two hours from now (``exp``).
    """
    from jose import jwt

    claims = {
        "sub": str(user_id),
        "kb": True,
        "pv": pwd_ver,
        "exp": datetime.utcnow() + timedelta(hours=2),
    }
    return jwt.encode(claims, SECRET_KEY, algorithm="HS256")
def verify_kb_access(
    x_kb_token: Optional[str] = Header(None),
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Dependency that requires a logged-in user AND a valid knowledge-base token.

    Args:
        x_kb_token: Value of the ``X-KB-Token`` request header.
        current_user: The authenticated user, resolved by ``get_current_user``.
        db: Database session, used to read the current password version.

    Returns:
        ``current_user`` when the token is accepted.

    Raises:
        HTTPException: 403 when the header is missing, the token cannot be
            decoded (bad signature, expired, malformed), it is not a
            knowledge-base token, it was issued to a different user, or it was
            issued under a different password version.
    """
    if not x_kb_token:
        raise HTTPException(status_code=403, detail="需要知识库访问权限,请先验证密码")

    from jose import jwt, JWTError

    # Keep the try body minimal: only decoding can raise JWTError.
    try:
        payload = jwt.decode(x_kb_token, SECRET_KEY, algorithms=["HS256"])
    except JWTError:
        raise HTTPException(status_code=403, detail="知识库令牌已过期,请重新验证密码") from None

    if not payload.get("kb"):
        raise HTTPException(status_code=403, detail="无效的知识库令牌")

    # The token is issued with sub=str(user.id); reject it when presented by
    # anyone else so that a token cannot be shared between accounts.
    if payload.get("sub") != str(current_user.id):
        raise HTTPException(status_code=403, detail="无效的知识库令牌")

    # A password change alters the version tag and thereby invalidates old tokens.
    if payload.get("pv", "") != _password_version(db):
        raise HTTPException(status_code=403, detail="密码已变更,请重新验证")

    return current_user
# ========== Schemas ==========
# Request body for creating a knowledge-base category.
class KbCategoryCreate(BaseModel):
    name: str  # category name; the create endpoint rejects a name that is already taken
    icon: str = ""  # icon value; defaults to an empty string when omitted
# Request body for a partial category update; every field is optional and
# fields left as None are not applied by the update endpoint.
class KbCategoryUpdate(BaseModel):
    name: Optional[str] = None
    icon: Optional[str] = None
    sort_order: Optional[int] = None  # categories are listed ordered by sort_order, then id
    is_active: Optional[bool] = None  # inactive categories are hidden from the non-admin listing
# Request body for adding existing posts to the knowledge base in bulk.
class KbItemAdd(BaseModel):
    post_ids: List[int]  # ids of the posts to add
    category_id: Optional[int] = None  # category assigned to every created item; None leaves them uncategorized
# Request body for a partial update of a knowledge-base item; every field is optional.
class KbItemUpdate(BaseModel):
    category_id: Optional[int] = None
    title: Optional[str] = None
    summary: Optional[str] = None
    sort_order: Optional[int] = None  # the non-admin item listing orders by sort_order first
    is_active: Optional[bool] = None  # inactive items are hidden from the non-admin endpoints
# Request body for the AI question-answering endpoint.
class KbAiChatRequest(BaseModel):
    question: str  # the user's question; the endpoint strips it and rejects an empty result
# Request body carrying a plain-text password (used by the auth and admin password endpoints).
class PasswordBody(BaseModel):
    password: str
# ========== 密码认证接口 ==========
@router.post("/auth")
def kb_auth(
    body: PasswordBody,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Exchange the knowledge-base password for a short-lived access token.

    When no password is configured, any logged-in user receives a token
    without a check.

    Args:
        body: Carries the plain-text password supplied by the user.
        current_user: The authenticated user the token is issued to.
        db: Database session.

    Returns:
        ``{"token": <jwt>}``.

    Raises:
        HTTPException: 401 when a password is configured and does not match.
    """
    import hmac

    stored = _get_kb_password(db)
    if stored:
        supplied = _hash_password(body.password)
        # Constant-time comparison avoids leaking how many leading characters
        # of the hash matched; bytes are used so non-ASCII stored values
        # cannot make compare_digest raise.
        if not hmac.compare_digest(supplied.encode(), stored.encode()):
            raise HTTPException(status_code=401, detail="密码错误")

    token = _create_kb_token(current_user.id, _password_version(db))
    return {"token": token}
@router.get("/check-password")
def kb_check_password(
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Tell a logged-in user whether the knowledge base is password protected.

    Returns:
        ``{"has_password": bool}`` — True when a non-empty password value is stored.
    """
    has_password = bool(_get_kb_password(db))
    return {"has_password": has_password}
# ========== 公开接口(需登录 + kb_token) ==========
@router.get("/categories")
def get_categories(
    user: User = Depends(verify_kb_access),
    db: Session = Depends(get_db),
):
    """List the active knowledge-base categories with their active item counts.

    Returns:
        A list of ``{"id", "name", "icon", "count"}`` dicts ordered by
        ``sort_order`` then ``id``; ``count`` is the number of active items
        in the category (0 when it has none).
    """
    # One grouped query instead of one COUNT per category.
    count_rows = (
        db.query(KbItem.category_id, sa_func.count(KbItem.id))
        .filter(KbItem.is_active == True)
        .group_by(KbItem.category_id)
        .all()
    )
    counts = {category_id: n for category_id, n in count_rows}

    cats = (
        db.query(KbCategory)
        .filter(KbCategory.is_active == True)
        .order_by(KbCategory.sort_order, KbCategory.id)
        .all()
    )
    return [
        {"id": c.id, "name": c.name, "icon": c.icon, "count": counts.get(c.id, 0)}
        for c in cats
    ]
@router.get("/items")
def get_items(
    page: int = Query(1, ge=1),
    size: int = Query(20, ge=1, le=100),
    category_id: Optional[int] = None,
    keyword: Optional[str] = None,
    user: User = Depends(verify_kb_access),
    db: Session = Depends(get_db),
):
    """Return one page of active knowledge-base items.

    Args:
        page: 1-based page number.
        size: Page size (1-100).
        category_id: When given, only items of this category are returned.
        keyword: When given, only items whose title or summary contains it
            are returned, and the search is written to the access log.
        user: The caller, validated by ``verify_kb_access``.
        db: Database session.

    Returns:
        ``{"items": [...], "total": int, "page": int, "size": int}``. An item
        without its own summary falls back to the first 150 characters of the
        linked post's content.
    """
    query = db.query(KbItem).filter(KbItem.is_active == True)
    if category_id:
        query = query.filter(KbItem.category_id == category_id)
    if keyword:
        kw = f"%{keyword}%"
        query = query.filter(or_(KbItem.title.like(kw), KbItem.summary.like(kw)))

    total = query.count()
    items = (
        query.order_by(KbItem.sort_order, KbItem.created_at.desc())
        .offset((page - 1) * size)
        .limit(size)
        .all()
    )

    # Resolve related rows with one query per table instead of three per item.
    post_ids = list({it.post_id for it in items if it.post_id is not None})
    cat_ids = list({it.category_id for it in items if it.category_id})
    user_ids = list({it.added_by for it in items if it.added_by})
    posts = (
        {p.id: p for p in db.query(Post).filter(Post.id.in_(post_ids)).all()}
        if post_ids else {}
    )
    cats = (
        {c.id: c for c in db.query(KbCategory).filter(KbCategory.id.in_(cat_ids)).all()}
        if cat_ids else {}
    )
    users = (
        {u.id: u for u in db.query(User).filter(User.id.in_(user_ids)).all()}
        if user_ids else {}
    )

    result = []
    for item in items:
        post = posts.get(item.post_id)
        cat = cats.get(item.category_id)
        author = users.get(item.added_by)
        # Post content may be NULL; treat it as empty rather than slicing None.
        fallback_summary = (post.content or "")[:150] if post else ""
        result.append({
            "id": item.id,
            "title": item.title,
            "summary": item.summary or fallback_summary,
            "category_id": item.category_id,
            "category_name": cat.name if cat else "",
            "post_id": item.post_id,
            "post_author": post.user_id if post else None,
            "added_by_name": author.username if author else "",
            "created_at": item.created_at.isoformat() if item.created_at else None,
        })

    # Record keyword searches in the access log.
    if keyword:
        db.add(KbAccessLog(user_id=user.id, action="search", query=keyword))
        db.commit()

    return {"items": result, "total": total, "page": page, "size": size}
@router.get("/items/{item_id}")
def get_item_detail(
    item_id: int,
    user: User = Depends(verify_kb_access),
    db: Session = Depends(get_db),
):
    """Return one active knowledge-base item together with its full post content.

    Every successful call writes a ``"view"`` entry to the access log.

    Raises:
        HTTPException: 404 when no active item has this id.
    """
    item = (
        db.query(KbItem)
        .filter(KbItem.id == item_id, KbItem.is_active == True)
        .first()
    )
    if not item:
        raise HTTPException(status_code=404, detail="条目不存在")

    post = db.query(Post).filter(Post.id == item.post_id).first()

    cat = None
    if item.category_id:
        cat = db.query(KbCategory).filter(KbCategory.id == item.category_id).first()

    post_author = None
    if post:
        post_author = db.query(User).filter(User.id == post.user_id).first()

    # Record the view before assembling the response.
    db.add(KbAccessLog(user_id=user.id, action="view", query=str(item_id)))
    db.commit()

    if post_author:
        author_info = {"id": post_author.id, "username": post_author.username}
    else:
        author_info = None

    # Attachments are only looked up when the linked post still exists.
    attachments = []
    if post:
        attachment_rows = (
            db.query(Attachment)
            .filter(Attachment.post_id == item.post_id)
            .order_by(Attachment.created_at.asc())
            .all()
        )
        for a in attachment_rows:
            attachments.append({
                "id": a.id,
                "filename": a.filename,
                "url": a.url,
                "file_size": a.file_size,
                "file_type": a.file_type,
            })

    return {
        "id": item.id,
        "title": item.title,
        "summary": item.summary,
        "category_id": item.category_id,
        "category_name": cat.name if cat else "",
        "post_id": item.post_id,
        "post_title": post.title if post else "",
        "post_content": post.content if post else "",
        "post_author": author_info,
        "post_tags": post.tags if post else "",
        "post_category": post.category if post else "",
        "created_at": item.created_at.isoformat() if item.created_at else None,
        "attachments": attachments,
    }
@router.get("/stats")
def get_kb_stats(
    user: User = Depends(verify_kb_access),
    db: Session = Depends(get_db),
):
    """Return headline counters for the knowledge base.

    Returns:
        Counts of active items, active categories, and access-log entries for
        the ``view``, ``search`` and ``ai_chat`` actions.
    """

    def count_logs(action):
        # Number of access-log rows recorded for one action type.
        return (
            db.query(sa_func.count(KbAccessLog.id))
            .filter(KbAccessLog.action == action)
            .scalar()
            or 0
        )

    active_items = (
        db.query(sa_func.count(KbItem.id))
        .filter(KbItem.is_active == True)
        .scalar()
        or 0
    )
    active_categories = (
        db.query(sa_func.count(KbCategory.id))
        .filter(KbCategory.is_active == True)
        .scalar()
        or 0
    )
    views = count_logs("view")
    searches = count_logs("search")
    ai_chats = count_logs("ai_chat")

    return {
        "total_items": active_items,
        "total_categories": active_categories,
        "total_views": views,
        "total_searches": searches,
        "total_ai_chats": ai_chats,
    }
# ========== AI 智能问答 ==========
# System prompt sent with every AI chat request. It tells the model to act as
# the team knowledge-base assistant and to: answer only from the supplied
# knowledge-base content without inventing facts, say so when nothing relevant
# exists, cite source article titles, format answers in Markdown, and stay
# concise and well organised.
KB_AI_SYSTEM_PROMPT = """你是一个团队知识库的AI助手。你的任务是根据知识库中的内容回答用户的问题。
规则:
1. 只基于提供的知识库内容回答问题,不编造信息
2. 如果知识库中没有相关内容,诚实告知用户
3. 回答时引用来源文章的标题
4. 使用 Markdown 格式组织回答
5. 回答要简洁、专业、有条理"""
@router.post("/ai-chat")
async def kb_ai_chat(
    request: KbAiChatRequest,
    user: User = Depends(verify_kb_access),
    db: Session = Depends(get_db),
):
    """Answer a question from knowledge-base content, streamed as server-sent events.

    Context is gathered by a simple keyword match on item title/summary using
    the first 50 characters of the question. When fewer than three items
    match, the newest active items are appended. Each linked post contributes
    at most 2000 characters of content. The question is written to the access
    log before streaming starts.

    Args:
        request: Carries the user's question.
        user: The caller, validated by ``verify_kb_access``.
        db: Database session.

    Returns:
        A ``text/event-stream`` response. Every event is ``data: <json>`` with
        ``content`` and ``done`` keys; the final event has ``done: true``. An
        AI failure is reported as a content event, not as an HTTP error.

    Raises:
        HTTPException: 400 when the question is empty after stripping.
    """
    question = request.question.strip()
    if not question:
        raise HTTPException(status_code=400, detail="请输入问题")

    # Simple keyword matching (could later be upgraded to vector search).
    kw = f"%{question[:50]}%"
    related_items = (
        db.query(KbItem)
        .filter(KbItem.is_active == True)
        .filter(or_(KbItem.title.like(kw), KbItem.summary.like(kw)))
        .limit(5)
        .all()
    )

    # Too few keyword hits: top up with the most recently created items.
    if len(related_items) < 3:
        filler = db.query(KbItem).filter(KbItem.is_active == True)
        if related_items:
            # Only add the exclusion when there is something to exclude,
            # instead of handing a bare Python ``True`` to filter().
            filler = filler.filter(~KbItem.id.in_([it.id for it in related_items]))
        related_items.extend(
            filler.order_by(KbItem.created_at.desc())
            .limit(10 - len(related_items))
            .all()
        )

    # Fetch all linked posts in one query rather than one per item.
    post_ids = list({it.post_id for it in related_items if it.post_id is not None})
    posts = (
        {p.id: p for p in db.query(Post).filter(Post.id.in_(post_ids)).all()}
        if post_ids else {}
    )

    context_parts = []
    for item in related_items:
        post = posts.get(item.post_id)
        if post:
            # Content may be NULL; cap each article's contribution at 2000 chars.
            content = (post.content or "")[:2000]
            context_parts.append(f"### {item.title}\n{content}")

    kb_context = "\n\n---\n\n".join(context_parts) if context_parts else "知识库暂无相关内容。"
    messages = [
        {"role": "user", "content": f"以下是知识库中的相关内容:\n\n{kb_context}\n\n---\n\n用户问题:{question}"}
    ]

    # Log the question.
    db.add(KbAccessLog(user_id=user.id, action="ai_chat", query=question))
    db.commit()

    async def generate():
        # Stream the AI answer; a non-streaming (str) result is sent as one event.
        try:
            result = await ai_service.chat(
                task_type="reasoning",
                messages=messages,
                system_prompt=KB_AI_SYSTEM_PROMPT,
                stream=True,
            )
            if isinstance(result, str):
                yield f"data: {json.dumps({'content': result, 'done': False})}\n\n"
            else:
                async for chunk in result:
                    yield f"data: {json.dumps({'content': chunk, 'done': False})}\n\n"
        except Exception as e:
            # Deliberate best-effort: surface the failure inside the stream so
            # the client still receives the terminating event below.
            error_msg = f"AI调用出错: {str(e)}"
            yield f"data: {json.dumps({'content': error_msg, 'done': False})}\n\n"
        yield f"data: {json.dumps({'content': '', 'done': True})}\n\n"

    return StreamingResponse(generate(), media_type="text/event-stream")
# ========== 管理员接口 ==========
# --- 密码管理 ---
@router.put("/admin/password")
def set_kb_password(
    body: PasswordBody,
    admin: User = Depends(get_admin_user),
    db: Session = Depends(get_db),
):
    """Set or change the knowledge-base password (admin only).

    The password is stored hashed under the ``"kb_password"`` config key. An
    empty password stores an empty value, which removes the protection.
    """
    hashed = ""
    if body.password:
        hashed = _hash_password(body.password)

    row = db.query(SystemConfig).filter(SystemConfig.key == "kb_password").first()
    if not row:
        # First time a password is configured: create the config row.
        db.add(SystemConfig(key="kb_password", value=hashed, description="知识库访问密码"))
    else:
        row.value = hashed

    db.commit()
    return {"message": "密码已更新"}
@router.get("/admin/password-status")
def get_kb_password_status(
    admin: User = Depends(get_admin_user),
    db: Session = Depends(get_db),
):
    """Report to an admin whether a knowledge-base password is configured.

    Returns:
        ``{"has_password": bool}``.
    """
    has_password = bool(_get_kb_password(db))
    return {"has_password": has_password}
# --- 分类管理 ---
@router.get("/admin/categories")
def admin_list_categories(
    admin: User = Depends(get_admin_user),
    db: Session = Depends(get_db),
):
    """List every knowledge-base category, including disabled ones (admin only).

    Returns:
        A list of category dicts ordered by ``sort_order`` then ``id``.
        ``item_count`` counts all items in the category, active or not.
    """
    # One grouped query instead of one COUNT per category.
    count_rows = (
        db.query(KbItem.category_id, sa_func.count(KbItem.id))
        .group_by(KbItem.category_id)
        .all()
    )
    counts = {category_id: n for category_id, n in count_rows}

    cats = db.query(KbCategory).order_by(KbCategory.sort_order, KbCategory.id).all()
    return [
        {
            "id": c.id, "name": c.name, "icon": c.icon,
            "sort_order": c.sort_order, "is_active": c.is_active,
            "item_count": counts.get(c.id, 0),
            "created_at": c.created_at.isoformat() if c.created_at else None,
        }
        for c in cats
    ]
@router.post("/admin/categories")
def admin_create_category(
    data: KbCategoryCreate,
    admin: User = Depends(get_admin_user),
    db: Session = Depends(get_db),
):
    """Create a knowledge-base category (admin only).

    Raises:
        HTTPException: 400 when a category with the same name already exists.
    """
    duplicate = db.query(KbCategory).filter(KbCategory.name == data.name).first()
    if duplicate:
        raise HTTPException(status_code=400, detail="分类名称已存在")

    new_cat = KbCategory(name=data.name, icon=data.icon)
    db.add(new_cat)
    db.commit()
    # Reload so the generated id is available for the response.
    db.refresh(new_cat)
    return {"id": new_cat.id, "name": new_cat.name, "message": "创建成功"}
@router.put("/admin/categories/{cat_id}")
def admin_update_category(
    cat_id: int,
    data: KbCategoryUpdate,
    admin: User = Depends(get_admin_user),
    db: Session = Depends(get_db),
):
    """Apply a partial update to a knowledge-base category (admin only).

    Only fields whose value is not None are written.

    Raises:
        HTTPException: 404 when the category does not exist.
    """
    target = db.query(KbCategory).filter(KbCategory.id == cat_id).first()
    if not target:
        raise HTTPException(status_code=404, detail="分类不存在")

    for field, new_value in data.dict(exclude_none=True).items():
        setattr(target, field, new_value)

    db.commit()
    return {"message": "更新成功"}
@router.delete("/admin/categories/{cat_id}")
def admin_delete_category(
    cat_id: int,
    admin: User = Depends(get_admin_user),
    db: Session = Depends(get_db),
):
    """Delete a knowledge-base category (admin only).

    Items that belonged to the category are kept and become uncategorized.

    Raises:
        HTTPException: 404 when the category does not exist.
    """
    target = db.query(KbCategory).filter(KbCategory.id == cat_id).first()
    if not target:
        raise HTTPException(status_code=404, detail="分类不存在")

    # Detach the category's items first so they survive as uncategorized.
    orphaned = db.query(KbItem).filter(KbItem.category_id == cat_id)
    orphaned.update({"category_id": None})

    db.delete(target)
    db.commit()
    return {"message": "删除成功"}
# --- 条目管理 ---
@router.get("/admin/items")
def admin_list_items(
    page: int = Query(1, ge=1),
    size: int = Query(20, ge=1, le=100),
    category_id: Optional[int] = None,
    keyword: Optional[str] = None,
    admin: User = Depends(get_admin_user),
    db: Session = Depends(get_db),
):
    """Return one page of all knowledge-base items, active or not (admin only).

    Args:
        page: 1-based page number.
        size: Page size (1-100).
        category_id: When given, only items of this category are returned.
        keyword: When given, only items whose title or summary contains it.
        admin: The calling admin.
        db: Database session.

    Returns:
        ``{"items": [...], "total": int, "page": int, "size": int}``, newest
        items first.
    """
    query = db.query(KbItem)
    if category_id:
        query = query.filter(KbItem.category_id == category_id)
    if keyword:
        kw = f"%{keyword}%"
        query = query.filter(or_(KbItem.title.like(kw), KbItem.summary.like(kw)))

    total = query.count()
    items = (
        query.order_by(KbItem.created_at.desc())
        .offset((page - 1) * size)
        .limit(size)
        .all()
    )

    # Resolve related rows with one query per table instead of up to four per item.
    post_ids = list({it.post_id for it in items if it.post_id is not None})
    posts = (
        {p.id: p for p in db.query(Post).filter(Post.id.in_(post_ids)).all()}
        if post_ids else {}
    )
    cat_ids = list({it.category_id for it in items if it.category_id})
    cats = (
        {c.id: c for c in db.query(KbCategory).filter(KbCategory.id.in_(cat_ids)).all()}
        if cat_ids else {}
    )
    # Users are needed both for "added by" and for the post authors.
    user_ids = {it.added_by for it in items if it.added_by}
    user_ids.update(p.user_id for p in posts.values() if p.user_id is not None)
    users = (
        {u.id: u for u in db.query(User).filter(User.id.in_(list(user_ids))).all()}
        if user_ids else {}
    )

    result = []
    for item in items:
        post = posts.get(item.post_id)
        cat = cats.get(item.category_id)
        author = users.get(item.added_by)
        post_author = users.get(post.user_id) if post else None
        result.append({
            "id": item.id,
            "title": item.title,
            "summary": item.summary or "",
            "category_id": item.category_id,
            "category_name": cat.name if cat else "未分类",
            "post_id": item.post_id,
            "post_title": post.title if post else "(已删除)",
            "post_author_name": post_author.username if post_author else "",
            "added_by_name": author.username if author else "",
            "is_active": item.is_active,
            "sort_order": item.sort_order,
            "created_at": item.created_at.isoformat() if item.created_at else None,
        })
    return {"items": result, "total": total, "page": page, "size": size}
@router.post("/admin/items")
def admin_add_items(
    data: KbItemAdd,
    admin: User = Depends(get_admin_user),
    db: Session = Depends(get_db),
):
    """Add existing posts to the knowledge base in bulk (admin only).

    A post is skipped when it is already in the knowledge base, does not
    exist, or appears more than once in the request. Each new item copies the
    post title and uses the first 200 characters of the content as summary.

    Returns:
        ``{"message": str, "added": int, "skipped": int}``.
    """
    added = 0
    skipped = 0
    # Ids already handled in this request. Without this, a repeated id could
    # be inserted twice when the session does not autoflush, because the
    # pending row is then invisible to the existence query below.
    seen = set()
    for post_id in data.post_ids:
        if post_id in seen:
            skipped += 1
            continue
        seen.add(post_id)

        # Skip posts that are already in the knowledge base.
        exists = db.query(KbItem).filter(KbItem.post_id == post_id).first()
        if exists:
            skipped += 1
            continue

        post = db.query(Post).filter(Post.id == post_id).first()
        if not post:
            skipped += 1
            continue

        db.add(KbItem(
            post_id=post_id,
            category_id=data.category_id,
            title=post.title,
            summary=post.content[:200] if post.content else "",
            added_by=admin.id,
        ))
        added += 1

    db.commit()
    return {"message": f"已添加 {added} 条,跳过 {skipped} 条(已存在或不存在)", "added": added, "skipped": skipped}
@router.put("/admin/items/{item_id}")
def admin_update_item(
    item_id: int,
    data: KbItemUpdate,
    admin: User = Depends(get_admin_user),
    db: Session = Depends(get_db),
):
    """Apply a partial update to a knowledge-base item (admin only).

    Only fields present in the request body are considered. A null value is
    ignored for every field except ``category_id``, where an explicit null
    moves the item back to "uncategorized".

    Raises:
        HTTPException: 404 when the item does not exist.
    """
    item = db.query(KbItem).filter(KbItem.id == item_id).first()
    if not item:
        raise HTTPException(status_code=404, detail="条目不存在")

    # exclude_unset distinguishes "field omitted" from "field sent as null";
    # exclude_none could not, so a category could never be cleared.
    updates = data.dict(exclude_unset=True)
    for key, value in updates.items():
        if value is None and key != "category_id":
            continue
        setattr(item, key, value)

    db.commit()
    return {"message": "更新成功"}
@router.delete("/admin/items/{item_id}")
def admin_delete_item(
    item_id: int,
    admin: User = Depends(get_admin_user),
    db: Session = Depends(get_db),
):
    """Remove an item from the knowledge base (admin only).

    Raises:
        HTTPException: 404 when the item does not exist.
    """
    target = db.query(KbItem).filter(KbItem.id == item_id).first()
    if target:
        db.delete(target)
        db.commit()
        return {"message": "删除成功"}
    raise HTTPException(status_code=404, detail="条目不存在")
@router.get("/admin/posts-for-pick")
def admin_posts_for_pick(
    page: int = Query(1, ge=1),
    size: int = Query(20, ge=1, le=50),
    keyword: Optional[str] = None,
    category: Optional[str] = None,
    admin: User = Depends(get_admin_user),
    db: Session = Depends(get_db),
):
    """Return one page of public posts not yet in the knowledge base (admin only).

    Args:
        page: 1-based page number.
        size: Page size (1-50).
        keyword: When given, only posts whose title or content contains it.
        category: When given, only posts of this post category.
        admin: The calling admin.
        db: Database session.

    Returns:
        ``{"items": [...], "total": int, "page": int, "size": int}``, newest
        posts first.
    """
    # Drop NULL post ids: a NULL inside a SQL NOT IN list makes the predicate
    # unknown for every row, which would hide all posts.
    existing_post_ids = [
        pid for (pid,) in db.query(KbItem.post_id).all() if pid is not None
    ]

    query = db.query(Post).filter(Post.is_public == True)
    if existing_post_ids:
        query = query.filter(~Post.id.in_(existing_post_ids))
    if keyword:
        kw = f"%{keyword}%"
        query = query.filter(or_(Post.title.like(kw), Post.content.like(kw)))
    if category:
        query = query.filter(Post.category == category)

    total = query.count()
    posts = (
        query.order_by(Post.created_at.desc())
        .offset((page - 1) * size)
        .limit(size)
        .all()
    )

    # Load all authors of this page in one query instead of one per post.
    author_ids = list({p.user_id for p in posts if p.user_id is not None})
    authors = (
        {u.id: u for u in db.query(User).filter(User.id.in_(author_ids)).all()}
        if author_ids else {}
    )

    result = []
    for p in posts:
        author = authors.get(p.user_id)
        result.append({
            "id": p.id,
            "title": p.title,
            "category": p.category,
            "content_preview": p.content[:100] if p.content else "",
            "author_name": author.username if author else "",
            "like_count": p.like_count,
            "view_count": p.view_count,
            "created_at": p.created_at.isoformat() if p.created_at else None,
        })
    return {"items": result, "total": total, "page": page, "size": size}
# --- 管理员统计 ---
@router.get("/admin/stats")
def admin_kb_stats(
    admin: User = Depends(get_admin_user),
    db: Session = Depends(get_db),
):
    """Return knowledge-base statistics for the admin dashboard.

    Returns:
        Overall counters (all items, active items, categories, and access-log
        totals per action) plus ``daily_stats``: one entry per day for the
        last seven days, oldest first, with that day's views, searches and
        AI chats.
    """
    from datetime import date

    def count_logs(action, start=None, end=None):
        # Access-log rows for one action, optionally limited to [start, end].
        q = db.query(sa_func.count(KbAccessLog.id)).filter(KbAccessLog.action == action)
        if start is not None:
            q = q.filter(KbAccessLog.created_at.between(start, end))
        return q.scalar() or 0

    total_items = db.query(sa_func.count(KbItem.id)).scalar() or 0
    active_items = (
        db.query(sa_func.count(KbItem.id)).filter(KbItem.is_active == True).scalar() or 0
    )
    total_categories = db.query(sa_func.count(KbCategory.id)).scalar() or 0
    total_views = count_logs("view")
    total_searches = count_logs("search")
    total_ai_chats = count_logs("ai_chat")

    # Trend over the last 7 days, oldest day first and today last.
    today = date.today()
    daily_stats = []
    for days_back in reversed(range(7)):
        d = today - timedelta(days=days_back)
        day_start = datetime.combine(d, datetime.min.time())
        day_end = datetime.combine(d, datetime.max.time())
        views = count_logs("view", day_start, day_end)
        searches = count_logs("search", day_start, day_end)
        ai_chats = count_logs("ai_chat", day_start, day_end)
        daily_stats.append({
            "date": d.isoformat(),
            "views": views,
            "searches": searches,
            "ai_chats": ai_chats,
        })

    return {
        "total_items": total_items,
        "active_items": active_items,
        "total_categories": total_categories,
        "total_views": total_views,
        "total_searches": total_searches,
        "total_ai_chats": total_ai_chats,
        "daily_stats": daily_stats,
    }