初始提交:极码 GeekCode 全栈项目(FastAPI + Vue3)
This commit is contained in:
0
backend/routers/__init__.py
Normal file
0
backend/routers/__init__.py
Normal file
458
backend/routers/admin.py
Normal file
458
backend/routers/admin.py
Normal file
@@ -0,0 +1,458 @@
|
||||
"""后台管理路由"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func as sa_func, distinct
|
||||
from datetime import datetime, timedelta
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
|
||||
from database import get_db
|
||||
from models.user import User
|
||||
from models.post import Post
|
||||
from models.comment import Comment
|
||||
from models.like import Like, Collect
|
||||
from models.system_config import SystemConfig
|
||||
from models.category import Category
|
||||
from routers.auth import get_admin_user, get_current_user
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# ---------- 对象存储配置管理(腾讯云COS) ----------
|
||||
|
||||
COS_CONFIG_KEYS = [
|
||||
{"key": "cos_secret_id", "description": "SecretId"},
|
||||
{"key": "cos_secret_key", "description": "SecretKey"},
|
||||
{"key": "cos_bucket", "description": "Bucket(如 bianchengshequ-1250000000)"},
|
||||
{"key": "cos_region", "description": "Region(如 ap-beijing)"},
|
||||
{"key": "cos_custom_domain", "description": "自定义域名(可选,CDN加速域名)"},
|
||||
]
|
||||
|
||||
|
||||
def get_cos_config_from_db(db: Session) -> dict:
|
||||
"""从数据库读取COS配置"""
|
||||
config = {}
|
||||
for item in COS_CONFIG_KEYS:
|
||||
row = db.query(SystemConfig).filter(SystemConfig.key == item["key"]).first()
|
||||
config[item["key"]] = row.value if row else ""
|
||||
return config
|
||||
|
||||
|
||||
class CosConfigUpdate(BaseModel):
|
||||
cos_secret_id: str = ""
|
||||
cos_secret_key: Optional[str] = None
|
||||
cos_bucket: str = ""
|
||||
cos_region: str = ""
|
||||
cos_custom_domain: str = ""
|
||||
|
||||
|
||||
@router.get("/storage/config")
|
||||
async def get_storage_config(
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user),
|
||||
):
|
||||
"""获取对象存储配置"""
|
||||
config = get_cos_config_from_db(db)
|
||||
# 脱敏 SecretKey
|
||||
secret = config.get("cos_secret_key", "")
|
||||
if secret and len(secret) > 6:
|
||||
config["cos_secret_key_masked"] = secret[:3] + "*" * (len(secret) - 6) + secret[-3:]
|
||||
else:
|
||||
config["cos_secret_key_masked"] = "*" * len(secret) if secret else ""
|
||||
config.pop("cos_secret_key", None)
|
||||
return {"config": config, "fields": COS_CONFIG_KEYS}
|
||||
|
||||
|
||||
@router.put("/storage/config")
|
||||
async def update_storage_config(
|
||||
data: CosConfigUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user),
|
||||
):
|
||||
"""更新对象存储配置"""
|
||||
updates = data.dict(exclude_none=True)
|
||||
for key, value in updates.items():
|
||||
row = db.query(SystemConfig).filter(SystemConfig.key == key).first()
|
||||
if row:
|
||||
row.value = value
|
||||
else:
|
||||
desc = next((i["description"] for i in COS_CONFIG_KEYS if i["key"] == key), "")
|
||||
db.add(SystemConfig(key=key, value=value, description=desc))
|
||||
db.commit()
|
||||
return {"message": "配置已保存"}
|
||||
|
||||
|
||||
@router.post("/storage/test")
|
||||
async def test_storage_connection(
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user),
|
||||
):
|
||||
"""测试COS连接"""
|
||||
config = get_cos_config_from_db(db)
|
||||
secret_id = config.get("cos_secret_id", "")
|
||||
secret_key = config.get("cos_secret_key", "")
|
||||
bucket = config.get("cos_bucket", "")
|
||||
region = config.get("cos_region", "")
|
||||
|
||||
if not all([secret_id, secret_key, bucket, region]):
|
||||
raise HTTPException(status_code=400, detail="COS配置不完整,请先填写所有必填项")
|
||||
|
||||
try:
|
||||
from qcloud_cos import CosConfig, CosS3Client
|
||||
cos_config = CosConfig(Region=region, SecretId=secret_id, SecretKey=secret_key)
|
||||
client = CosS3Client(cos_config)
|
||||
# 尝试获取bucket信息来验证连接
|
||||
client.head_bucket(Bucket=bucket)
|
||||
return {"success": True, "message": "连接成功"}
|
||||
except ImportError:
|
||||
raise HTTPException(status_code=500, detail="服务器未安装 cos-python-sdk-v5 库,请执行 pip install cos-python-sdk-v5")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"连接失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_stats(
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user),
|
||||
):
|
||||
"""获取管理后台统计数据"""
|
||||
today = datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
# 基础统计
|
||||
total_users = db.query(sa_func.count(User.id)).scalar() or 0
|
||||
total_posts = db.query(sa_func.count(Post.id)).scalar() or 0
|
||||
total_comments = db.query(sa_func.count(Comment.id)).scalar() or 0
|
||||
total_likes = db.query(sa_func.count(Like.id)).scalar() or 0
|
||||
|
||||
# 今日新增
|
||||
today_users = db.query(sa_func.count(User.id)).filter(User.created_at >= today).scalar() or 0
|
||||
today_posts = db.query(sa_func.count(Post.id)).filter(Post.created_at >= today).scalar() or 0
|
||||
|
||||
# 今日活跃(今日有发帖/评论/点赞行为的用户)
|
||||
active_post = db.query(distinct(Post.user_id)).filter(Post.created_at >= today)
|
||||
active_comment = db.query(distinct(Comment.user_id)).filter(Comment.created_at >= today)
|
||||
active_like = db.query(distinct(Like.user_id)).filter(Like.created_at >= today)
|
||||
active_ids = set()
|
||||
for row in active_post.all():
|
||||
active_ids.add(row[0])
|
||||
for row in active_comment.all():
|
||||
active_ids.add(row[0])
|
||||
for row in active_like.all():
|
||||
active_ids.add(row[0])
|
||||
today_active = len(active_ids)
|
||||
|
||||
# 7日趋势
|
||||
user_trend = []
|
||||
post_trend = []
|
||||
for i in range(6, -1, -1):
|
||||
day_start = today - timedelta(days=i)
|
||||
day_end = day_start + timedelta(days=1)
|
||||
date_str = day_start.strftime("%m-%d")
|
||||
|
||||
u_count = db.query(sa_func.count(User.id)).filter(
|
||||
User.created_at >= day_start, User.created_at < day_end
|
||||
).scalar() or 0
|
||||
p_count = db.query(sa_func.count(Post.id)).filter(
|
||||
Post.created_at >= day_start, Post.created_at < day_end
|
||||
).scalar() or 0
|
||||
|
||||
user_trend.append({"date": date_str, "count": u_count})
|
||||
post_trend.append({"date": date_str, "count": p_count})
|
||||
|
||||
# 最近注册用户
|
||||
recent_users = db.query(User).order_by(User.created_at.desc()).limit(5).all()
|
||||
recent_users_data = [
|
||||
{"id": u.id, "username": u.username, "email": u.email, "created_at": str(u.created_at)}
|
||||
for u in recent_users
|
||||
]
|
||||
|
||||
# 最近发布帖子
|
||||
recent_posts = db.query(Post).order_by(Post.created_at.desc()).limit(5).all()
|
||||
recent_posts_data = []
|
||||
for p in recent_posts:
|
||||
author = db.query(User).filter(User.id == p.user_id).first()
|
||||
recent_posts_data.append({
|
||||
"id": p.id, "title": p.title,
|
||||
"author": author.username if author else "未知",
|
||||
"created_at": str(p.created_at),
|
||||
})
|
||||
|
||||
return {
|
||||
"total_users": total_users,
|
||||
"total_posts": total_posts,
|
||||
"total_comments": total_comments,
|
||||
"total_likes": total_likes,
|
||||
"today_users": today_users,
|
||||
"today_posts": today_posts,
|
||||
"today_active": today_active,
|
||||
"user_trend": user_trend,
|
||||
"post_trend": post_trend,
|
||||
"recent_users": recent_users_data,
|
||||
"recent_posts": recent_posts_data,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/users")
|
||||
async def list_users(
|
||||
search: str = "",
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=100),
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user),
|
||||
):
|
||||
"""用户管理列表"""
|
||||
query = db.query(User)
|
||||
if search:
|
||||
query = query.filter(User.username.contains(search))
|
||||
|
||||
total = query.count()
|
||||
users = query.order_by(User.created_at.desc()).offset((page - 1) * page_size).limit(page_size).all()
|
||||
|
||||
items = []
|
||||
for u in users:
|
||||
post_count = db.query(sa_func.count(Post.id)).filter(Post.user_id == u.id).scalar() or 0
|
||||
comment_count = db.query(sa_func.count(Comment.id)).filter(Comment.user_id == u.id).scalar() or 0
|
||||
items.append({
|
||||
"id": u.id,
|
||||
"username": u.username,
|
||||
"email": u.email,
|
||||
"avatar": u.avatar or "",
|
||||
"is_admin": u.is_admin,
|
||||
"is_banned": getattr(u, 'is_banned', False),
|
||||
"is_approved": getattr(u, 'is_approved', True),
|
||||
"post_count": post_count,
|
||||
"comment_count": comment_count,
|
||||
"created_at": str(u.created_at),
|
||||
})
|
||||
|
||||
return {"items": items, "total": total, "page": page, "page_size": page_size}
|
||||
|
||||
|
||||
@router.put("/users/{user_id}/toggle-admin")
|
||||
async def toggle_admin(
|
||||
user_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user),
|
||||
):
|
||||
"""切换用户管理员身份"""
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
if user.id == admin.id:
|
||||
raise HTTPException(status_code=400, detail="不能修改自己的管理员状态")
|
||||
user.is_admin = not user.is_admin
|
||||
db.commit()
|
||||
return {"message": f"已{'设为' if user.is_admin else '取消'}管理员", "is_admin": user.is_admin}
|
||||
|
||||
|
||||
@router.put("/users/{user_id}/toggle-ban")
|
||||
async def toggle_ban(
|
||||
user_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user),
|
||||
):
|
||||
"""封禁/解封用户"""
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
if user.id == admin.id:
|
||||
raise HTTPException(status_code=400, detail="不能封禁自己")
|
||||
if user.is_admin:
|
||||
raise HTTPException(status_code=400, detail="不能封禁管理员")
|
||||
user.is_banned = not user.is_banned
|
||||
db.commit()
|
||||
return {"message": f"已{'封禁' if user.is_banned else '解封'}该用户", "is_banned": user.is_banned}
|
||||
|
||||
|
||||
@router.put("/users/{user_id}/approve")
|
||||
async def approve_user(
|
||||
user_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user),
|
||||
):
|
||||
"""审核通过用户"""
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
if getattr(user, 'is_approved', False):
|
||||
raise HTTPException(status_code=400, detail="该用户已通过审核")
|
||||
user.is_approved = True
|
||||
db.commit()
|
||||
return {"message": f"已审核通过用户:{user.username}", "is_approved": True}
|
||||
|
||||
|
||||
@router.put("/users/{user_id}/reject")
|
||||
async def reject_user(
|
||||
user_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user),
|
||||
):
|
||||
"""拒绝/撤回用户审核"""
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
if user.is_admin:
|
||||
raise HTTPException(status_code=400, detail="不能拒绝管理员")
|
||||
user.is_approved = False
|
||||
db.commit()
|
||||
return {"message": f"已拒绝用户:{user.username}", "is_approved": False}
|
||||
|
||||
|
||||
@router.get("/posts")
|
||||
async def list_posts(
|
||||
search: str = "",
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=100),
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user),
|
||||
):
|
||||
"""帖子管理列表"""
|
||||
query = db.query(Post)
|
||||
if search:
|
||||
query = query.filter(Post.title.contains(search))
|
||||
|
||||
total = query.count()
|
||||
posts = query.order_by(Post.created_at.desc()).offset((page - 1) * page_size).limit(page_size).all()
|
||||
|
||||
items = []
|
||||
for p in posts:
|
||||
author = db.query(User).filter(User.id == p.user_id).first()
|
||||
like_count = db.query(sa_func.count(Like.id)).filter(Like.post_id == p.id).scalar() or 0
|
||||
comment_count = db.query(sa_func.count(Comment.id)).filter(Comment.post_id == p.id).scalar() or 0
|
||||
items.append({
|
||||
"id": p.id,
|
||||
"title": p.title,
|
||||
"author": author.username if author else "未知",
|
||||
"author_id": p.user_id,
|
||||
"category": p.category or "",
|
||||
"is_public": p.is_public,
|
||||
"like_count": like_count,
|
||||
"comment_count": comment_count,
|
||||
"view_count": p.view_count,
|
||||
"created_at": str(p.created_at),
|
||||
})
|
||||
|
||||
return {"items": items, "total": total, "page": page, "page_size": page_size}
|
||||
|
||||
|
||||
@router.delete("/posts/{post_id}")
|
||||
async def delete_post(
|
||||
post_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user),
|
||||
):
|
||||
"""管理员删除帖子"""
|
||||
post = db.query(Post).filter(Post.id == post_id).first()
|
||||
if not post:
|
||||
raise HTTPException(status_code=404, detail="帖子不存在")
|
||||
# 删除关联数据
|
||||
db.query(Comment).filter(Comment.post_id == post_id).delete()
|
||||
db.query(Like).filter(Like.post_id == post_id).delete()
|
||||
db.query(Collect).filter(Collect.post_id == post_id).delete()
|
||||
db.delete(post)
|
||||
db.commit()
|
||||
return {"message": "删除成功"}
|
||||
|
||||
|
||||
# ---------- 分类管理 ----------
|
||||
|
||||
@router.get("/categories")
|
||||
async def list_categories(
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user),
|
||||
):
|
||||
"""获取所有分类(含禁用的)"""
|
||||
cats = db.query(Category).order_by(Category.sort_order, Category.id).all()
|
||||
return [{"id": c.id, "name": c.name, "sort_order": c.sort_order, "is_active": c.is_active} for c in cats]
|
||||
|
||||
|
||||
class CategoryCreate(BaseModel):
|
||||
name: str
|
||||
|
||||
class CategoryUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
sort_order: Optional[int] = None
|
||||
is_active: Optional[bool] = None
|
||||
|
||||
|
||||
@router.post("/categories")
|
||||
async def create_category(
|
||||
data: CategoryCreate,
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user),
|
||||
):
|
||||
"""新增分类"""
|
||||
existing = db.query(Category).filter(Category.name == data.name).first()
|
||||
if existing:
|
||||
raise HTTPException(status_code=400, detail="分类名称已存在")
|
||||
max_order = db.query(sa_func.max(Category.sort_order)).scalar() or 0
|
||||
cat = Category(name=data.name, sort_order=max_order + 1)
|
||||
db.add(cat)
|
||||
db.commit()
|
||||
db.refresh(cat)
|
||||
return {"id": cat.id, "name": cat.name, "sort_order": cat.sort_order, "is_active": cat.is_active}
|
||||
|
||||
|
||||
@router.put("/categories/{cat_id}")
|
||||
async def update_category(
|
||||
cat_id: int,
|
||||
data: CategoryUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user),
|
||||
):
|
||||
"""修改分类"""
|
||||
cat = db.query(Category).filter(Category.id == cat_id).first()
|
||||
if not cat:
|
||||
raise HTTPException(status_code=404, detail="分类不存在")
|
||||
if data.name is not None:
|
||||
dup = db.query(Category).filter(Category.name == data.name, Category.id != cat_id).first()
|
||||
if dup:
|
||||
raise HTTPException(status_code=400, detail="分类名称已存在")
|
||||
cat.name = data.name
|
||||
if data.sort_order is not None:
|
||||
cat.sort_order = data.sort_order
|
||||
if data.is_active is not None:
|
||||
cat.is_active = data.is_active
|
||||
db.commit()
|
||||
return {"id": cat.id, "name": cat.name, "sort_order": cat.sort_order, "is_active": cat.is_active}
|
||||
|
||||
|
||||
@router.delete("/categories/{cat_id}")
|
||||
async def delete_category(
|
||||
cat_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user),
|
||||
):
|
||||
"""删除分类"""
|
||||
cat = db.query(Category).filter(Category.id == cat_id).first()
|
||||
if not cat:
|
||||
raise HTTPException(status_code=404, detail="分类不存在")
|
||||
db.delete(cat)
|
||||
db.commit()
|
||||
return {"message": "删除成功"}
|
||||
|
||||
|
||||
@router.put("/categories/reorder")
|
||||
async def reorder_categories(
|
||||
items: list,
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user),
|
||||
):
|
||||
"""批量更新分类排序"""
|
||||
for item in items:
|
||||
cat = db.query(Category).filter(Category.id == item["id"]).first()
|
||||
if cat:
|
||||
cat.sort_order = item["sort_order"]
|
||||
db.commit()
|
||||
return {"message": "排序已更新"}
|
||||
|
||||
|
||||
# ---------- 公开分类API(无需管理员权限) ----------
|
||||
|
||||
@router.get("/public/categories")
|
||||
async def get_public_categories(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""获取启用的分类列表(前台使用)"""
|
||||
cats = db.query(Category).filter(Category.is_active == True).order_by(Category.sort_order, Category.id).all()
|
||||
return [c.name for c in cats]
|
||||
223
backend/routers/ai_format.py
Normal file
223
backend/routers/ai_format.py
Normal file
@@ -0,0 +1,223 @@
|
||||
"""AI智能排版路由 - 文本排版 + 自动配图"""
|
||||
import json
|
||||
import uuid
|
||||
import httpx
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models.user import User
|
||||
from models.ai_model import AIModelConfig
|
||||
from routers.auth import get_current_user
|
||||
from services.ai_service import ai_service
|
||||
|
||||
router = APIRouter(prefix="/api/ai", tags=["AI智能排版"])
|
||||
|
||||
FORMAT_SYSTEM_PROMPT = """你是一位专业的文章排版编辑。你的任务是将用户提供的文章内容重新排版为结构清晰、可读性强的 Markdown 格式。
|
||||
|
||||
【核心原则】绝对不能修改、删除、改写或替换原文的任何文字内容。原文的每一个字、每一句话都必须原样保留。你可以在合适的位置补充过渡语、小结或说明文字来增强可读性,但必须与原文明确区分,且不能改动原文已有的文字。
|
||||
|
||||
排版规则:
|
||||
1. 分析文章结构,在合适位置添加标题层级(## 和 ###),标题文字从原文中提取
|
||||
2. 将原文中的要点整理为列表(有序或无序),但列表内容必须是原文原句
|
||||
3. 重要观点用引用块 > 包裹,引用内容必须是原文原句
|
||||
4. 关键词和重要内容用 **加粗** 标记
|
||||
5. 保留原文中所有图片链接(格式),不要修改或删除
|
||||
6. 保留原文中所有URL链接,转为 [链接文字](url) 格式
|
||||
7. 适当添加分隔线 --- 划分章节
|
||||
8. 长段落拆分为短段落,提高可读性,但不能改变段落中的文字
|
||||
9. 如果内容中有流程或步骤,用有序列表清晰展示
|
||||
10. 不要添加原文中没有的文字、解释、总结或过渡语
|
||||
|
||||
同时,你需要分析文章内容,为文章建议 1-2 张配图。对于每张配图,在合适的位置插入占位符,格式为:
|
||||
[AI_IMAGE: 图片描述prompt,用英文写,描述要生成的图片内容,风格简洁专业]
|
||||
|
||||
注意:
|
||||
- 占位符要插在文章逻辑合适的位置(如章节开头、流程说明旁边)
|
||||
- prompt 用英文描述,风格:flat illustration, modern, professional, tech style
|
||||
- 不要超过 2 个图片占位符
|
||||
- 如果文章内容不适合配图(如纯代码、纯链接列表),可以不加占位符"""
|
||||
|
||||
|
||||
class FormatRequest(BaseModel):
|
||||
model_config = {"protected_namespaces": ()}
|
||||
content: str
|
||||
generate_images: bool = False # 是否生成配图(默认关闭)
|
||||
model_config_id: Optional[int] = None # 指定排版用的文本模型
|
||||
|
||||
|
||||
class FormatResponse(BaseModel):
|
||||
formatted_content: str
|
||||
images_generated: int = 0
|
||||
|
||||
|
||||
def _get_image_model(db: Session):
|
||||
"""从数据库查找已配置的图像生成模型(Seedream endpoint)"""
|
||||
# 查找 task_type 包含 image 或 model_name 包含 seedream 的模型
|
||||
model = db.query(AIModelConfig).filter(
|
||||
AIModelConfig.is_enabled == True,
|
||||
AIModelConfig.task_type == "image",
|
||||
).first()
|
||||
if model:
|
||||
return {
|
||||
"api_key": model.api_key,
|
||||
"base_url": model.base_url or "https://ark.cn-beijing.volces.com/api/v3",
|
||||
"model_id": model.model_id,
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
async def _generate_image(api_key: str, base_url: str, model_id: str, prompt: str) -> Optional[bytes]:
|
||||
"""调用火山方舟 Seedream 生成图片,返回图片二进制数据"""
|
||||
url = f"{base_url.rstrip('/')}/images/generations"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
payload = {
|
||||
"model": model_id,
|
||||
"prompt": prompt,
|
||||
"response_format": "url",
|
||||
"size": "1920x1080", # 满足 Seedream 5.0 最低像素要求
|
||||
}
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=60) as client:
|
||||
resp = await client.post(url, json=payload, headers=headers)
|
||||
if resp.status_code != 200:
|
||||
print(f"图像生成失败: {resp.status_code} - {resp.text}")
|
||||
return None
|
||||
data = resp.json()
|
||||
image_url = data.get("data", [{}])[0].get("url", "")
|
||||
if not image_url:
|
||||
return None
|
||||
# 下载图片(因为 Seedream 返回的是临时链接)
|
||||
img_resp = await client.get(image_url)
|
||||
if img_resp.status_code == 200:
|
||||
return img_resp.content
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"图像生成异常: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _upload_to_cos(db: Session, image_data: bytes) -> Optional[str]:
|
||||
"""将图片上传到 COS,返回永久 URL"""
|
||||
from models.system_config import SystemConfig
|
||||
keys = ["cos_secret_id", "cos_secret_key", "cos_bucket", "cos_region", "cos_custom_domain"]
|
||||
config = {}
|
||||
for k in keys:
|
||||
row = db.query(SystemConfig).filter(SystemConfig.key == k).first()
|
||||
config[k] = row.value if row else ""
|
||||
|
||||
secret_id = config.get("cos_secret_id", "")
|
||||
secret_key = config.get("cos_secret_key", "")
|
||||
bucket = config.get("cos_bucket", "")
|
||||
region = config.get("cos_region", "")
|
||||
|
||||
if not all([secret_id, secret_key, bucket, region]):
|
||||
return None
|
||||
|
||||
try:
|
||||
from qcloud_cos import CosConfig, CosS3Client
|
||||
cos_config = CosConfig(Region=region, SecretId=secret_id, SecretKey=secret_key)
|
||||
client = CosS3Client(cos_config)
|
||||
|
||||
date_prefix = datetime.now().strftime("%Y/%m")
|
||||
filename = f"{uuid.uuid4().hex}.png"
|
||||
object_key = f"images/{date_prefix}/{filename}"
|
||||
|
||||
client.put_object(
|
||||
Bucket=bucket,
|
||||
Body=image_data,
|
||||
Key=object_key,
|
||||
ContentType="image/png",
|
||||
)
|
||||
|
||||
custom_domain = config.get("cos_custom_domain", "")
|
||||
if custom_domain:
|
||||
return f"https://{custom_domain}/{object_key}"
|
||||
return f"https://{bucket}.cos.{region}.myqcloud.com/{object_key}"
|
||||
except Exception as e:
|
||||
print(f"COS上传失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
@router.post("/format", response_model=FormatResponse)
|
||||
async def format_article(
|
||||
req: FormatRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""AI智能排版:格式化文章 + 自动生成配图"""
|
||||
if not req.content.strip():
|
||||
raise HTTPException(status_code=400, detail="文章内容不能为空")
|
||||
|
||||
# 第1步:AI 排版文本
|
||||
messages = [{"role": "user", "content": req.content}]
|
||||
try:
|
||||
formatted = await ai_service.chat(
|
||||
task_type="reasoning",
|
||||
messages=messages,
|
||||
system_prompt=FORMAT_SYSTEM_PROMPT,
|
||||
stream=False,
|
||||
model_config_id=req.model_config_id,
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"AI排版失败: {str(e)}")
|
||||
|
||||
if not isinstance(formatted, str):
|
||||
raise HTTPException(status_code=500, detail="AI排版返回格式异常")
|
||||
|
||||
# 过滤掉思考过程(DeepSeek Reasoner 会输出 <think>...</think>)
|
||||
import re as _re
|
||||
formatted = _re.sub(r'<think>[\s\S]*?</think>\s*', '', formatted).strip()
|
||||
# 也过滤 <details> 格式的思考过程
|
||||
formatted = _re.sub(r'<details>[\s\S]*?</details>\s*', '', formatted).strip()
|
||||
|
||||
images_generated = 0
|
||||
|
||||
# 第2步:生成配图(如果启用)
|
||||
if req.generate_images:
|
||||
import re
|
||||
placeholders = re.findall(r'\[AI_IMAGE:\s*(.+?)\]', formatted)
|
||||
|
||||
if placeholders:
|
||||
image_model = _get_image_model(db)
|
||||
if image_model:
|
||||
for prompt in placeholders:
|
||||
image_data = await _generate_image(
|
||||
image_model["api_key"],
|
||||
image_model["base_url"],
|
||||
image_model["model_id"],
|
||||
prompt.strip(),
|
||||
)
|
||||
if image_data:
|
||||
# 上传到 COS
|
||||
cos_url = _upload_to_cos(db, image_data)
|
||||
if cos_url:
|
||||
formatted = formatted.replace(
|
||||
f"[AI_IMAGE: {prompt}]",
|
||||
f"![{prompt.strip()[:50]}]({cos_url})",
|
||||
1,
|
||||
)
|
||||
images_generated += 1
|
||||
continue
|
||||
# 生成或上传失败,移除占位符
|
||||
formatted = formatted.replace(f"[AI_IMAGE: {prompt}]", "", 1)
|
||||
else:
|
||||
# 没有配置图像模型,清理所有占位符
|
||||
formatted = re.sub(r'\[AI_IMAGE:\s*.+?\]', '', formatted)
|
||||
|
||||
# 清理可能残留的占位符
|
||||
import re
|
||||
formatted = re.sub(r'\[AI_IMAGE:\s*.+?\]', '', formatted)
|
||||
# 清理多余空行
|
||||
formatted = re.sub(r'\n{3,}', '\n\n', formatted).strip()
|
||||
|
||||
return FormatResponse(
|
||||
formatted_content=formatted,
|
||||
images_generated=images_generated,
|
||||
)
|
||||
285
backend/routers/ai_models.py
Normal file
285
backend/routers/ai_models.py
Normal file
@@ -0,0 +1,285 @@
|
||||
"""AI模型管理路由"""
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List
|
||||
from database import get_db
|
||||
from models.ai_model import AIModelConfig
|
||||
from models.user import User
|
||||
from schemas.ai_model import AIModelCreate, AIModelUpdate, AIModelResponse, ProviderInfo
|
||||
from routers.auth import get_admin_user, get_current_user
|
||||
|
||||
router = APIRouter(prefix="/api/admin/models", tags=["AI模型管理"])
|
||||
|
||||
# 公开路由(登录用户可用,用于前台 AI 工具获取可选模型)
|
||||
public_router = APIRouter(prefix="/api/models", tags=["AI模型公开"])
|
||||
|
||||
# 预置的服务商和模型信息
|
||||
PROVIDER_PRESETS = [
|
||||
{
|
||||
"provider": "deepseek",
|
||||
"name": "DeepSeek",
|
||||
"default_base_url": "https://api.deepseek.com",
|
||||
"models": [
|
||||
{"model_id": "deepseek-chat", "name": "DeepSeek-V3.2", "task_types": ["lightweight", "knowledge_base"], "description": "DeepSeek-V3.2 非思考模式,性价比极高"},
|
||||
{"model_id": "deepseek-reasoner", "name": "DeepSeek-V3.2 思考", "task_types": ["reasoning", "knowledge_base"], "description": "DeepSeek-V3.2 思考模式,带推理链输出"},
|
||||
]
|
||||
},
|
||||
{
|
||||
"provider": "openai",
|
||||
"name": "OpenAI",
|
||||
"default_base_url": "https://api.openai.com/v1",
|
||||
"models": [
|
||||
{"model_id": "gpt-4o", "name": "GPT-4o", "task_types": ["multimodal", "reasoning"], "description": "多模态旗舰模型"},
|
||||
{"model_id": "gpt-4o-mini", "name": "GPT-4o Mini", "task_types": ["lightweight"], "description": "轻量高效模型"},
|
||||
{"model_id": "o3-mini", "name": "o3-mini", "task_types": ["reasoning"], "description": "推理增强模型"},
|
||||
{"model_id": "text-embedding-3-large", "name": "Embedding Large", "task_types": ["embedding"], "description": "高维度文本嵌入模型"},
|
||||
]
|
||||
},
|
||||
{
|
||||
"provider": "anthropic",
|
||||
"name": "Anthropic",
|
||||
"default_base_url": "https://api.anthropic.com",
|
||||
"models": [
|
||||
{"model_id": "claude-sonnet-4-20250514", "name": "Claude Sonnet 4", "task_types": ["reasoning"], "description": "Claude最新推理模型"},
|
||||
{"model_id": "claude-3-5-haiku-20241022", "name": "Claude 3.5 Haiku", "task_types": ["lightweight"], "description": "快速轻量模型"},
|
||||
]
|
||||
},
|
||||
{
|
||||
"provider": "google",
|
||||
"name": "Google Gemini",
|
||||
"default_base_url": "https://generativelanguage.googleapis.com/v1beta/openai",
|
||||
"models": [
|
||||
{"model_id": "gemini-2.5-pro-preview-06-05", "name": "Gemini 2.5 Pro", "task_types": ["multimodal", "reasoning"], "description": "多模态能力最强"},
|
||||
{"model_id": "gemini-2.0-flash", "name": "Gemini 2.0 Flash", "task_types": ["lightweight", "multimodal"], "description": "快速多模态模型"},
|
||||
]
|
||||
},
|
||||
{
|
||||
"provider": "ark",
|
||||
"name": "火山方舟(豆包)",
|
||||
"default_base_url": "https://ark.cn-beijing.volces.com/api/v3",
|
||||
"models": [
|
||||
{"model_id": "ep-20260411180700-z6nll", "name": "Doubao-Seed-2.0-pro", "task_types": ["reasoning", "lightweight", "knowledge_base"], "description": "豆包旗舰模型,支持联网搜索"},
|
||||
{"model_id": "doubao-seedream-5-0-260128", "name": "Seedream 5.0 (图像生成)", "task_types": ["image"], "description": "豆包图像生成模型,支持文生图"},
|
||||
]
|
||||
},
|
||||
]
|
||||
|
||||
TASK_TYPE_LABELS = {
|
||||
"multimodal": "多模态(图片/草图理解)",
|
||||
"reasoning": "推理分析(需求解读/架构分析)",
|
||||
"lightweight": "轻量任务(分类/标签)",
|
||||
"knowledge_base": "知识库分析(文档理解/问答)",
|
||||
"embedding": "向量嵌入",
|
||||
"image": "图像生成(AI配图/文生图)",
|
||||
}
|
||||
|
||||
|
||||
def _mask_api_key(key: str) -> str:
|
||||
"""API Key脱敏"""
|
||||
if not key or len(key) < 8:
|
||||
return "****" if key else ""
|
||||
return key[:4] + "*" * (len(key) - 8) + key[-4:]
|
||||
|
||||
|
||||
def _to_response(model: AIModelConfig) -> dict:
|
||||
"""转换为响应格式"""
|
||||
return {
|
||||
"id": model.id,
|
||||
"provider": model.provider,
|
||||
"provider_name": model.provider_name,
|
||||
"model_id": model.model_id,
|
||||
"model_name": model.model_name,
|
||||
"api_key_masked": _mask_api_key(model.api_key),
|
||||
"base_url": model.base_url,
|
||||
"task_type": model.task_type,
|
||||
"is_enabled": model.is_enabled,
|
||||
"is_default": model.is_default,
|
||||
"web_search_enabled": model.web_search_enabled,
|
||||
"description": model.description,
|
||||
"created_at": model.created_at,
|
||||
"updated_at": model.updated_at,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/presets", response_model=List[ProviderInfo])
|
||||
async def get_provider_presets():
|
||||
"""获取预置的服务商和模型列表"""
|
||||
return PROVIDER_PRESETS
|
||||
|
||||
|
||||
@router.get("/task-types")
|
||||
async def get_task_types():
|
||||
"""获取任务类型列表"""
|
||||
return TASK_TYPE_LABELS
|
||||
|
||||
|
||||
@router.get("", response_model=List[AIModelResponse])
|
||||
async def list_models(provider: str = None, task_type: str = None, db: Session = Depends(get_db)):
|
||||
"""获取所有已配置的模型"""
|
||||
query = db.query(AIModelConfig)
|
||||
if provider:
|
||||
query = query.filter(AIModelConfig.provider == provider)
|
||||
if task_type:
|
||||
query = query.filter(AIModelConfig.task_type == task_type)
|
||||
models = query.order_by(AIModelConfig.provider, AIModelConfig.created_at).all()
|
||||
return [_to_response(m) for m in models]
|
||||
|
||||
|
||||
@router.post("", response_model=AIModelResponse)
|
||||
async def create_model(data: AIModelCreate, db: Session = Depends(get_db), admin: User = Depends(get_admin_user)):
|
||||
"""添加模型配置"""
|
||||
model = AIModelConfig(
|
||||
provider=data.provider,
|
||||
provider_name=data.provider_name,
|
||||
model_id=data.model_id,
|
||||
model_name=data.model_name,
|
||||
api_key=data.api_key,
|
||||
base_url=data.base_url,
|
||||
task_type=data.task_type,
|
||||
is_enabled=data.is_enabled,
|
||||
is_default=data.is_default,
|
||||
web_search_enabled=data.web_search_enabled,
|
||||
description=data.description,
|
||||
)
|
||||
# 如果设为默认,取消同任务类型的其他默认
|
||||
if data.is_default and data.task_type:
|
||||
db.query(AIModelConfig).filter(
|
||||
AIModelConfig.task_type == data.task_type,
|
||||
AIModelConfig.is_default == True
|
||||
).update({"is_default": False})
|
||||
db.add(model)
|
||||
db.commit()
|
||||
db.refresh(model)
|
||||
return _to_response(model)
|
||||
|
||||
|
||||
@router.put("/{model_id}", response_model=AIModelResponse)
|
||||
async def update_model(model_id: int, data: AIModelUpdate, db: Session = Depends(get_db), admin: User = Depends(get_admin_user)):
|
||||
"""更新模型配置"""
|
||||
model = db.query(AIModelConfig).filter(AIModelConfig.id == model_id).first()
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="模型配置不存在")
|
||||
|
||||
update_data = data.dict(exclude_unset=True)
|
||||
|
||||
# 如果API Key为空字符串,表示不修改
|
||||
if "api_key" in update_data and update_data["api_key"] == "":
|
||||
del update_data["api_key"]
|
||||
|
||||
# 如果设为默认,取消同任务类型的其他默认
|
||||
if update_data.get("is_default") and (update_data.get("task_type") or model.task_type):
|
||||
task = update_data.get("task_type", model.task_type)
|
||||
db.query(AIModelConfig).filter(
|
||||
AIModelConfig.task_type == task,
|
||||
AIModelConfig.is_default == True,
|
||||
AIModelConfig.id != model_id
|
||||
).update({"is_default": False})
|
||||
|
||||
for key, value in update_data.items():
|
||||
setattr(model, key, value)
|
||||
db.commit()
|
||||
db.refresh(model)
|
||||
return _to_response(model)
|
||||
|
||||
|
||||
@router.delete("/{model_id}")
|
||||
async def delete_model(model_id: int, db: Session = Depends(get_db), admin: User = Depends(get_admin_user)):
|
||||
"""删除模型配置"""
|
||||
model = db.query(AIModelConfig).filter(AIModelConfig.id == model_id).first()
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="模型配置不存在")
|
||||
db.delete(model)
|
||||
db.commit()
|
||||
return {"message": "删除成功"}
|
||||
|
||||
|
||||
@router.post("/init-defaults")
|
||||
async def init_default_models(db: Session = Depends(get_db), admin: User = Depends(get_admin_user)):
|
||||
"""初始化默认模型配置(仅当数据库为空时)"""
|
||||
count = db.query(AIModelConfig).count()
|
||||
if count > 0:
|
||||
return {"message": f"已有 {count} 条配置,跳过初始化", "count": count}
|
||||
|
||||
defaults = [
|
||||
AIModelConfig(provider="deepseek", provider_name="DeepSeek", model_id="deepseek-chat",
|
||||
model_name="DeepSeek-V3", task_type="reasoning", is_default=True, is_enabled=True,
|
||||
base_url="https://api.deepseek.com/v1", description="DeepSeek最新对话模型,性价比极高"),
|
||||
AIModelConfig(provider="deepseek", provider_name="DeepSeek", model_id="deepseek-reasoner",
|
||||
model_name="DeepSeek-R1", task_type="", is_enabled=True,
|
||||
base_url="https://api.deepseek.com/v1", description="深度推理模型,适合复杂逻辑分析"),
|
||||
AIModelConfig(provider="openai", provider_name="OpenAI", model_id="gpt-4o-mini",
|
||||
model_name="GPT-4o Mini", task_type="lightweight", is_default=True, is_enabled=True,
|
||||
description="轻量高效模型"),
|
||||
AIModelConfig(provider="google", provider_name="Google Gemini", model_id="gemini-2.5-pro-preview-06-05",
|
||||
model_name="Gemini 2.5 Pro", task_type="multimodal", is_default=True, is_enabled=True,
|
||||
base_url="https://generativelanguage.googleapis.com/v1beta/openai",
|
||||
description="多模态能力最强"),
|
||||
]
|
||||
db.add_all(defaults)
|
||||
db.commit()
|
||||
return {"message": f"已初始化 {len(defaults)} 条默认配置", "count": len(defaults)}
|
||||
|
||||
|
||||
@router.post("/{model_id}/test")
|
||||
async def test_model_connection(model_id: int, db: Session = Depends(get_db), admin: User = Depends(get_admin_user)):
|
||||
"""测试模型连接是否正常"""
|
||||
model = db.query(AIModelConfig).filter(AIModelConfig.id == model_id).first()
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="模型配置不存在")
|
||||
if not model.api_key:
|
||||
return {"success": False, "message": "未配置 API Key"}
|
||||
|
||||
try:
|
||||
import httpx
|
||||
headers = {"Authorization": f"Bearer {model.api_key}", "Content-Type": "application/json"}
|
||||
base_url = model.base_url or f"https://api.{model.provider}.com/v1"
|
||||
|
||||
if model.provider == "anthropic":
|
||||
headers = {"x-api-key": model.api_key, "anthropic-version": "2023-06-01", "Content-Type": "application/json"}
|
||||
url = f"{base_url}/v1/messages"
|
||||
payload = {"model": model.model_id, "max_tokens": 10, "messages": [{"role": "user", "content": "hi"}]}
|
||||
else:
|
||||
url = f"{base_url}/chat/completions"
|
||||
payload = {"model": model.model_id, "max_tokens": 10, "messages": [{"role": "user", "content": "hi"}]}
|
||||
|
||||
async with httpx.AsyncClient(timeout=15.0) as client:
|
||||
resp = await client.post(url, json=payload, headers=headers)
|
||||
if resp.status_code == 200:
|
||||
return {"success": True, "message": "连接成功"}
|
||||
else:
|
||||
return {"success": False, "message": f"HTTP {resp.status_code}: {resp.text[:200]}"}
|
||||
except Exception as e:
|
||||
return {"success": False, "message": f"连接失败: {str(e)}"}
|
||||
|
||||
|
||||
# ===== 公开接口 =====
|
||||
|
||||
@public_router.get("/available")
|
||||
async def get_available_models(
|
||||
task_type: str = "",
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取可用模型列表(登录用户可调用,不返回 API Key)"""
|
||||
query = db.query(AIModelConfig).filter(
|
||||
AIModelConfig.is_enabled == True,
|
||||
AIModelConfig.api_key != "",
|
||||
)
|
||||
if task_type:
|
||||
query = query.filter(AIModelConfig.task_type == task_type)
|
||||
models = query.order_by(AIModelConfig.is_default.desc(), AIModelConfig.created_at).all()
|
||||
return [
|
||||
{
|
||||
"id": m.id,
|
||||
"provider": m.provider,
|
||||
"provider_name": m.provider_name,
|
||||
"model_id": m.model_id,
|
||||
"model_name": m.model_name,
|
||||
"task_type": m.task_type,
|
||||
"is_default": m.is_default,
|
||||
"web_search_enabled": m.web_search_enabled,
|
||||
"web_search_count": m.web_search_count or 5,
|
||||
"description": m.description,
|
||||
}
|
||||
for m in models
|
||||
]
|
||||
296
backend/routers/architecture.py
Normal file
296
backend/routers/architecture.py
Normal file
@@ -0,0 +1,296 @@
|
||||
"""架构选型助手路由"""
|
||||
import json
|
||||
from typing import List
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models.user import User
|
||||
from models.conversation import Conversation, Message
|
||||
from schemas.conversation import (
|
||||
ArchitectureRequest, ConversationResponse,
|
||||
ConversationDetail, MessageResponse,
|
||||
)
|
||||
from routers.auth import get_current_user
|
||||
from services.ai_service import ai_service
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
ARCHITECTURE_SYSTEM_PROMPT = """# 角色定义
|
||||
你是一位拥有10年+经验的**高级全栈架构师**,精通前端(Vue/React/小程序)、后端(Python/Java/Go/Node.js)、数据库(MySQL/PostgreSQL/MongoDB/Redis)、云服务与DevOps。你做过大量从0到1的项目,对技术选型的利弊、不同规模系统的架构模式了如指掌。
|
||||
|
||||
你的工作:接收用户提供的**已确认的功能需求**(可能来自需求助手的输出),给出完整的、可直接落地开发的技术方案。
|
||||
|
||||
> ⚠️ 本助手专注于**技术选型与架构设计**。如果用户发来的是原始甲方需求,建议先到「需求理解助手」进行需求分析。
|
||||
|
||||
# 核心理念
|
||||
- **没有最好的技术,只有最合适的技术**:选型必须匹配项目规模、团队能力和预算
|
||||
- **方案要能落地写代码**:不出纯理论的架构图,给的方案要具体到程序员能直接开干
|
||||
- **过度设计是大忌**:小项目用微服务是灾难,要敢于推荐简单方案
|
||||
|
||||
# 分析框架
|
||||
|
||||
## 第一步:项目画像评估
|
||||
- 项目规模:小型(个人/小团队)/ 中型(创业公司)/ 大型(企业级)
|
||||
- 预期用户量和并发量
|
||||
- 团队技术栈偏好(如果用户有提及)
|
||||
- 预算和时间约束
|
||||
|
||||
## 第二步:技术选型(带对比和理由)
|
||||
针对每一层给出推荐方案和备选方案:
|
||||
- 前端框架 + UI组件库
|
||||
- 后端语言 + Web框架
|
||||
- 数据库(主库 + 缓存)
|
||||
- 文件存储方案
|
||||
- 部署方案
|
||||
- 第三方服务(如果需要)
|
||||
|
||||
## 第三步:系统架构设计
|
||||
- 整体架构图(Mermaid语法)
|
||||
- 核心数据模型(ER关系、表结构)
|
||||
- 关键接口设计(RESTful API清单)
|
||||
- 目录结构规划
|
||||
|
||||
## 第四步:技术难点与避坑指南
|
||||
- 基于实战经验,针对该项目的具体技术难点给出解决方案
|
||||
- 常见踩坑点和规避策略
|
||||
- 安全注意事项(XSS、CSRF、SQL注入、越权等)
|
||||
|
||||
## 第五步:开发路线图
|
||||
- MVP版本应包含哪些功能
|
||||
- 迭代计划建议
|
||||
- 工期评估(按模块拆分前后端工时)
|
||||
|
||||
# 输出规范
|
||||
严格使用以下 Markdown 结构输出:
|
||||
|
||||
---
|
||||
|
||||
## 🎯 项目画像
|
||||
| 维度 | 评估 |
|
||||
|------|------|
|
||||
| 项目规模 | xxx |
|
||||
| 预期用户量 | xxx |
|
||||
| 推荐架构模式 | 单体/前后端分离/微服务 |
|
||||
|
||||
## 🏗️ 技术选型
|
||||
| 层级 | 推荐方案 | 备选方案 | 选型理由 |
|
||||
|------|---------|---------|---------|
|
||||
| 前端框架 | xxx | xxx | xxx |
|
||||
| UI组件库 | xxx | xxx | xxx |
|
||||
| 后端框架 | xxx | xxx | xxx |
|
||||
| 数据库 | xxx | xxx | xxx |
|
||||
| 缓存 | xxx | xxx | xxx |
|
||||
| 部署 | xxx | xxx | xxx |
|
||||
|
||||
## 📐 系统架构图
|
||||
```mermaid
|
||||
graph TB
|
||||
A[前端] --> B[API网关]
|
||||
B --> C[后端服务]
|
||||
C --> D[数据库]
|
||||
```
|
||||
|
||||
## 🗄️ 核心数据模型
|
||||
```sql
|
||||
-- 表名: xxx
|
||||
-- 说明: xxx
|
||||
CREATE TABLE xxx (
|
||||
id BIGINT PRIMARY KEY AUTO_INCREMENT,
|
||||
xxx VARCHAR(255) NOT NULL COMMENT 'xxx',
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
-- 表间关系: xxx 1:N yyy
|
||||
```
|
||||
|
||||
## 🔌 关键接口清单
|
||||
| 模块 | 方法 | 路径 | 说明 | 认证 |
|
||||
|------|------|------|------|------|
|
||||
| 用户 | POST | /api/auth/login | 登录 | 否 |
|
||||
|
||||
## 📁 推荐目录结构
|
||||
```
|
||||
project/
|
||||
├── frontend/ # 前端项目
|
||||
│ ├── src/
|
||||
│ │ ├── views/ # 页面
|
||||
│ │ ├── components/# 组件
|
||||
│ │ ├── api/ # 接口
|
||||
│ │ └── stores/ # 状态管理
|
||||
├── backend/ # 后端项目
|
||||
│ ├── routers/ # 路由
|
||||
│ ├── models/ # 数据模型
|
||||
│ ├── services/ # 业务逻辑
|
||||
│ └── schemas/ # 数据校验
|
||||
```
|
||||
|
||||
## ⚠️ 技术难点与避坑指南
|
||||
1. **【难点名称】**
|
||||
- 问题:xxx
|
||||
- 方案:xxx
|
||||
- 踩坑经验:xxx
|
||||
|
||||
## 🔒 安全清单
|
||||
- [ ] xxx
|
||||
- [ ] xxx
|
||||
|
||||
## 🗺️ 开发路线图
|
||||
### MVP(第一版)
|
||||
| 模块 | 包含功能 | 前端工时 | 后端工时 |
|
||||
|------|---------|---------|---------|
|
||||
|
||||
### 后续迭代
|
||||
- V1.1: xxx
|
||||
- V1.2: xxx
|
||||
|
||||
---
|
||||
|
||||
# 交互原则
|
||||
1. **选型必须带理由**:不说"推荐用Vue",要说"推荐Vue 3,因为xxx;如果团队熟悉React也可以用"
|
||||
2. **方案要分档**:针对不同预算/规模给出不同方案(如"预算充足用云服务,省钱可以用VPS")
|
||||
3. **代码要能跑**:给出的SQL、目录结构、接口设计都要是可以直接使用的
|
||||
4. **架构图用Mermaid**:使用 ```mermaid 代码块,只用基础语法,不加样式
|
||||
5. **敢于说"不需要"**:如果项目不需要Redis/微服务/消息队列,要直说,不为了显得高级而过度设计
|
||||
6. **持续深化**:用户追问某个模块时,在已有方案基础上深入展开,保持一致性"""
|
||||
|
||||
|
||||
@router.get("/conversations", response_model=List[ConversationResponse])
|
||||
def get_conversations(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取架构对话列表"""
|
||||
conversations = (
|
||||
db.query(Conversation)
|
||||
.filter(Conversation.user_id == current_user.id, Conversation.type == "architecture")
|
||||
.order_by(Conversation.updated_at.desc())
|
||||
.all()
|
||||
)
|
||||
return [ConversationResponse.model_validate(c) for c in conversations]
|
||||
|
||||
|
||||
@router.get("/conversations/{conversation_id}", response_model=ConversationDetail)
|
||||
def get_conversation_detail(
|
||||
conversation_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取对话详情"""
|
||||
conv = db.query(Conversation).filter(
|
||||
Conversation.id == conversation_id,
|
||||
Conversation.user_id == current_user.id,
|
||||
).first()
|
||||
if not conv:
|
||||
raise HTTPException(status_code=404, detail="对话不存在")
|
||||
|
||||
messages = (
|
||||
db.query(Message)
|
||||
.filter(Message.conversation_id == conversation_id)
|
||||
.order_by(Message.created_at.asc())
|
||||
.all()
|
||||
)
|
||||
result = ConversationDetail.model_validate(conv)
|
||||
result.messages = [MessageResponse.model_validate(m) for m in messages]
|
||||
return result
|
||||
|
||||
|
||||
@router.post("/recommend")
|
||||
async def recommend_architecture(
|
||||
request: ArchitectureRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""架构推荐 - 流式输出"""
|
||||
# 创建或获取对话
|
||||
if request.conversation_id:
|
||||
conv = db.query(Conversation).filter(
|
||||
Conversation.id == request.conversation_id,
|
||||
Conversation.user_id == current_user.id,
|
||||
).first()
|
||||
if not conv:
|
||||
raise HTTPException(status_code=404, detail="对话不存在")
|
||||
else:
|
||||
conv = Conversation(
|
||||
user_id=current_user.id,
|
||||
title=request.content[:50] if request.content else "新架构咨询",
|
||||
type="architecture",
|
||||
)
|
||||
db.add(conv)
|
||||
db.commit()
|
||||
db.refresh(conv)
|
||||
|
||||
# 保存用户消息
|
||||
user_msg = Message(
|
||||
conversation_id=conv.id,
|
||||
role="user",
|
||||
content=request.content,
|
||||
)
|
||||
db.add(user_msg)
|
||||
db.commit()
|
||||
|
||||
# 构建历史消息
|
||||
history_msgs = (
|
||||
db.query(Message)
|
||||
.filter(Message.conversation_id == conv.id)
|
||||
.order_by(Message.created_at.asc())
|
||||
.all()
|
||||
)
|
||||
messages = [{"role": msg.role, "content": msg.content} for msg in history_msgs]
|
||||
|
||||
# 流式调用AI
|
||||
async def generate():
|
||||
full_response = ""
|
||||
try:
|
||||
result = await ai_service.chat(
|
||||
task_type="reasoning",
|
||||
messages=messages,
|
||||
system_prompt=ARCHITECTURE_SYSTEM_PROMPT,
|
||||
stream=True,
|
||||
model_config_id=request.model_config_id,
|
||||
)
|
||||
if isinstance(result, str):
|
||||
full_response = result
|
||||
yield f"data: {json.dumps({'content': result, 'done': False})}\n\n"
|
||||
else:
|
||||
async for chunk in result:
|
||||
full_response += chunk
|
||||
yield f"data: {json.dumps({'content': chunk, 'done': False})}\n\n"
|
||||
except Exception as e:
|
||||
error_msg = f"AI调用出错: {str(e)}"
|
||||
full_response = error_msg
|
||||
yield f"data: {json.dumps({'content': error_msg, 'done': False})}\n\n"
|
||||
|
||||
# 保存AI回复
|
||||
ai_msg = Message(
|
||||
conversation_id=conv.id,
|
||||
role="assistant",
|
||||
content=full_response,
|
||||
)
|
||||
db.add(ai_msg)
|
||||
db.commit()
|
||||
|
||||
yield f"data: {json.dumps({'content': '', 'done': True, 'conversation_id': conv.id})}\n\n"
|
||||
|
||||
return StreamingResponse(generate(), media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.delete("/conversations/{conversation_id}")
|
||||
def delete_conversation(
|
||||
conversation_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""删除对话"""
|
||||
conv = db.query(Conversation).filter(
|
||||
Conversation.id == conversation_id,
|
||||
Conversation.user_id == current_user.id,
|
||||
).first()
|
||||
if not conv:
|
||||
raise HTTPException(status_code=404, detail="对话不存在")
|
||||
|
||||
db.query(Message).filter(Message.conversation_id == conversation_id).delete()
|
||||
db.delete(conv)
|
||||
db.commit()
|
||||
return {"message": "删除成功"}
|
||||
136
backend/routers/auth.py
Normal file
136
backend/routers/auth.py
Normal file
@@ -0,0 +1,136 @@
|
||||
"""认证路由"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
from passlib.context import CryptContext
|
||||
from jose import JWTError, jwt
|
||||
from datetime import datetime, timedelta
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
|
||||
from database import get_db
|
||||
from models.user import User
|
||||
from schemas.user import UserRegister, UserLogin, UserResponse, TokenResponse, UserUpdate
|
||||
from config import SECRET_KEY, ALGORITHM, ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
|
||||
router = APIRouter()
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
security = HTTPBearer()
|
||||
|
||||
|
||||
def create_access_token(data: dict) -> str:
|
||||
"""创建JWT Token"""
|
||||
to_encode = data.copy()
|
||||
expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
to_encode.update({"exp": expire})
|
||||
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
||||
|
||||
def get_current_user(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
db: Session = Depends(get_db),
|
||||
) -> User:
|
||||
"""从Token获取当前用户"""
|
||||
token = credentials.credentials
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
user_id = payload.get("sub")
|
||||
if user_id is None:
|
||||
raise HTTPException(status_code=401, detail="无效的认证凭据")
|
||||
user_id = int(user_id)
|
||||
except JWTError:
|
||||
raise HTTPException(status_code=401, detail="无效的认证凭据")
|
||||
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if user is None:
|
||||
raise HTTPException(status_code=401, detail="用户不存在")
|
||||
return user
|
||||
|
||||
|
||||
def get_admin_user(current_user: User = Depends(get_current_user)) -> User:
|
||||
"""要求当前用户是管理员"""
|
||||
if not current_user.is_admin:
|
||||
raise HTTPException(status_code=403, detail="需要管理员权限")
|
||||
return current_user
|
||||
|
||||
|
||||
@router.post("/register")
|
||||
def register(data: UserRegister, db: Session = Depends(get_db)):
|
||||
"""用户注册(需管理员审核后才可使用)"""
|
||||
# 检查用户名是否已存在
|
||||
if db.query(User).filter(User.username == data.username).first():
|
||||
raise HTTPException(status_code=400, detail="用户名已存在")
|
||||
if db.query(User).filter(User.email == data.email).first():
|
||||
raise HTTPException(status_code=400, detail="邮箱已被注册")
|
||||
|
||||
# 创建用户(is_approved 默认 False,等待审核)
|
||||
user = User(
|
||||
username=data.username,
|
||||
email=data.email,
|
||||
password_hash=pwd_context.hash(data.password),
|
||||
)
|
||||
db.add(user)
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
|
||||
return {"message": "注册成功,请等待管理员审核通过后即可登录使用"}
|
||||
|
||||
|
||||
@router.post("/login", response_model=TokenResponse)
|
||||
def login(data: UserLogin, db: Session = Depends(get_db)):
|
||||
"""用户登录"""
|
||||
user = db.query(User).filter(User.username == data.username).first()
|
||||
if not user or not pwd_context.verify(data.password, user.password_hash):
|
||||
raise HTTPException(status_code=401, detail="用户名或密码错误")
|
||||
|
||||
if getattr(user, 'is_banned', False):
|
||||
raise HTTPException(status_code=403, detail="账号已被封禁,请联系管理员")
|
||||
|
||||
if not getattr(user, 'is_approved', False):
|
||||
raise HTTPException(status_code=403, detail="账号尚未通过审核,请耐心等待管理员审核")
|
||||
|
||||
token = create_access_token({"sub": str(user.id)})
|
||||
return TokenResponse(
|
||||
access_token=token,
|
||||
user=UserResponse.model_validate(user),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
|
||||
def get_me(current_user: User = Depends(get_current_user)):
|
||||
"""获取当前用户信息"""
|
||||
return UserResponse.model_validate(current_user)
|
||||
|
||||
|
||||
@router.put("/profile", response_model=UserResponse)
|
||||
def update_profile(
|
||||
data: UserUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""更新个人资料"""
|
||||
# 修改用户名
|
||||
if data.username and data.username != current_user.username:
|
||||
if db.query(User).filter(User.username == data.username, User.id != current_user.id).first():
|
||||
raise HTTPException(status_code=400, detail="用户名已存在")
|
||||
current_user.username = data.username
|
||||
|
||||
# 修改邮箱
|
||||
if data.email and data.email != current_user.email:
|
||||
if db.query(User).filter(User.email == data.email, User.id != current_user.id).first():
|
||||
raise HTTPException(status_code=400, detail="邮箱已被使用")
|
||||
current_user.email = data.email
|
||||
|
||||
# 修改头像
|
||||
if data.avatar is not None:
|
||||
current_user.avatar = data.avatar
|
||||
|
||||
# 修改密码
|
||||
if data.new_password:
|
||||
if not data.old_password:
|
||||
raise HTTPException(status_code=400, detail="请输入当前密码")
|
||||
if not pwd_context.verify(data.old_password, current_user.password_hash):
|
||||
raise HTTPException(status_code=400, detail="当前密码错误")
|
||||
current_user.password_hash = pwd_context.hash(data.new_password)
|
||||
|
||||
db.commit()
|
||||
db.refresh(current_user)
|
||||
return UserResponse.model_validate(current_user)
|
||||
118
backend/routers/bookmarks.py
Normal file
118
backend/routers/bookmarks.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""网站收藏路由"""
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List
|
||||
from database import get_db
|
||||
from models.bookmark import BookmarkSite
|
||||
from models.user import User
|
||||
from schemas.bookmark import BookmarkCreate, BookmarkUpdate, BookmarkResponse, ReorderRequest
|
||||
from routers.auth import get_current_user
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("", response_model=List[BookmarkResponse])
|
||||
def get_bookmarks(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取当前用户的收藏网站列表"""
|
||||
bookmarks = (
|
||||
db.query(BookmarkSite)
|
||||
.filter(BookmarkSite.user_id == current_user.id)
|
||||
.order_by(BookmarkSite.sort_order, BookmarkSite.created_at)
|
||||
.all()
|
||||
)
|
||||
return [BookmarkResponse.model_validate(b) for b in bookmarks]
|
||||
|
||||
|
||||
@router.post("", response_model=BookmarkResponse)
|
||||
def create_bookmark(
|
||||
data: BookmarkCreate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""添加收藏网站"""
|
||||
# 获取当前最大排序值
|
||||
max_order = (
|
||||
db.query(BookmarkSite.sort_order)
|
||||
.filter(BookmarkSite.user_id == current_user.id)
|
||||
.order_by(BookmarkSite.sort_order.desc())
|
||||
.first()
|
||||
)
|
||||
next_order = (max_order[0] + 1) if max_order else 0
|
||||
|
||||
bookmark = BookmarkSite(
|
||||
user_id=current_user.id,
|
||||
name=data.name,
|
||||
url=data.url,
|
||||
icon=data.icon,
|
||||
sort_order=next_order,
|
||||
)
|
||||
db.add(bookmark)
|
||||
db.commit()
|
||||
db.refresh(bookmark)
|
||||
return BookmarkResponse.model_validate(bookmark)
|
||||
|
||||
|
||||
@router.put("/reorder")
|
||||
def reorder_bookmarks(
|
||||
data: ReorderRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""批量更新排序"""
|
||||
for item in data.items:
|
||||
db.query(BookmarkSite).filter(
|
||||
BookmarkSite.id == item.id,
|
||||
BookmarkSite.user_id == current_user.id,
|
||||
).update({"sort_order": item.sort_order})
|
||||
db.commit()
|
||||
return {"message": "排序已更新"}
|
||||
|
||||
|
||||
@router.put("/{bookmark_id}", response_model=BookmarkResponse)
|
||||
def update_bookmark(
|
||||
bookmark_id: int,
|
||||
data: BookmarkUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""编辑收藏网站"""
|
||||
bookmark = db.query(BookmarkSite).filter(
|
||||
BookmarkSite.id == bookmark_id,
|
||||
BookmarkSite.user_id == current_user.id,
|
||||
).first()
|
||||
if not bookmark:
|
||||
raise HTTPException(status_code=404, detail="收藏不存在")
|
||||
|
||||
if data.name is not None:
|
||||
bookmark.name = data.name
|
||||
if data.url is not None:
|
||||
bookmark.url = data.url
|
||||
if data.icon is not None:
|
||||
bookmark.icon = data.icon
|
||||
if data.sort_order is not None:
|
||||
bookmark.sort_order = data.sort_order
|
||||
|
||||
db.commit()
|
||||
db.refresh(bookmark)
|
||||
return BookmarkResponse.model_validate(bookmark)
|
||||
|
||||
|
||||
@router.delete("/{bookmark_id}")
|
||||
def delete_bookmark(
|
||||
bookmark_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""删除收藏网站"""
|
||||
bookmark = db.query(BookmarkSite).filter(
|
||||
BookmarkSite.id == bookmark_id,
|
||||
BookmarkSite.user_id == current_user.id,
|
||||
).first()
|
||||
if not bookmark:
|
||||
raise HTTPException(status_code=404, detail="收藏不存在")
|
||||
db.delete(bookmark)
|
||||
db.commit()
|
||||
return {"message": "删除成功"}
|
||||
633
backend/routers/knowledge_base.py
Normal file
633
backend/routers/knowledge_base.py
Normal file
@@ -0,0 +1,633 @@
|
||||
"""团队知识库路由"""
|
||||
import json
|
||||
import hashlib
|
||||
from typing import Optional, List
|
||||
from datetime import datetime, timedelta
|
||||
from fastapi import APIRouter, Depends, HTTPException, Header, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func as sa_func, or_
|
||||
from pydantic import BaseModel
|
||||
|
||||
from database import get_db
|
||||
from config import SECRET_KEY
|
||||
from models.user import User
|
||||
from models.post import Post
|
||||
from models.system_config import SystemConfig
|
||||
from models.knowledge_base import KbCategory, KbItem, KbAccessLog
|
||||
from models.attachment import Attachment
|
||||
from routers.auth import get_current_user, get_admin_user
|
||||
from services.ai_service import ai_service
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ========== 密码机制(复用 API Hub 方案) ==========
|
||||
|
||||
def _get_kb_password(db: Session) -> str:
|
||||
cfg = db.query(SystemConfig).filter(SystemConfig.key == "kb_password").first()
|
||||
return cfg.value if cfg else ""
|
||||
|
||||
|
||||
def _password_version(db: Session) -> str:
|
||||
"""返回密码哈希前8位作为版本标识,密码变更后旧token自动失效"""
|
||||
pwd = _get_kb_password(db)
|
||||
return pwd[:8] if pwd else "none"
|
||||
|
||||
|
||||
def _hash_password(pwd: str) -> str:
|
||||
return hashlib.sha256(pwd.encode()).hexdigest()
|
||||
|
||||
|
||||
def _create_kb_token(user_id: int, pwd_ver: str = "none") -> str:
|
||||
from jose import jwt
|
||||
exp = datetime.utcnow() + timedelta(hours=2)
|
||||
return jwt.encode({"sub": str(user_id), "kb": True, "pv": pwd_ver, "exp": exp}, SECRET_KEY, algorithm="HS256")
|
||||
|
||||
|
||||
def verify_kb_access(
|
||||
x_kb_token: Optional[str] = Header(None),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""验证用户登录 + 知识库访问令牌"""
|
||||
if not x_kb_token:
|
||||
raise HTTPException(status_code=403, detail="需要知识库访问权限,请先验证密码")
|
||||
from jose import jwt, JWTError
|
||||
try:
|
||||
payload = jwt.decode(x_kb_token, SECRET_KEY, algorithms=["HS256"])
|
||||
if not payload.get("kb"):
|
||||
raise HTTPException(status_code=403, detail="无效的知识库令牌")
|
||||
# 检查密码版本是否匹配
|
||||
token_pv = payload.get("pv", "")
|
||||
current_pv = _password_version(db)
|
||||
if token_pv != current_pv:
|
||||
raise HTTPException(status_code=403, detail="密码已变更,请重新验证")
|
||||
except JWTError:
|
||||
raise HTTPException(status_code=403, detail="知识库令牌已过期,请重新验证密码")
|
||||
return current_user
|
||||
|
||||
|
||||
# ========== Schemas ==========
|
||||
|
||||
class KbCategoryCreate(BaseModel):
|
||||
name: str
|
||||
icon: str = ""
|
||||
|
||||
class KbCategoryUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
icon: Optional[str] = None
|
||||
sort_order: Optional[int] = None
|
||||
is_active: Optional[bool] = None
|
||||
|
||||
class KbItemAdd(BaseModel):
|
||||
post_ids: List[int]
|
||||
category_id: Optional[int] = None
|
||||
|
||||
class KbItemUpdate(BaseModel):
|
||||
category_id: Optional[int] = None
|
||||
title: Optional[str] = None
|
||||
summary: Optional[str] = None
|
||||
sort_order: Optional[int] = None
|
||||
is_active: Optional[bool] = None
|
||||
|
||||
class KbAiChatRequest(BaseModel):
|
||||
question: str
|
||||
|
||||
class PasswordBody(BaseModel):
|
||||
password: str
|
||||
|
||||
|
||||
# ========== 密码认证接口 ==========
|
||||
|
||||
@router.post("/auth")
|
||||
def kb_auth(
|
||||
body: PasswordBody,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""知识库密码验证"""
|
||||
stored = _get_kb_password(db)
|
||||
pv = _password_version(db)
|
||||
if not stored:
|
||||
# 未设置密码,直接放行
|
||||
token = _create_kb_token(current_user.id, pv)
|
||||
return {"token": token}
|
||||
if _hash_password(body.password) != stored:
|
||||
raise HTTPException(status_code=401, detail="密码错误")
|
||||
token = _create_kb_token(current_user.id, pv)
|
||||
return {"token": token}
|
||||
|
||||
|
||||
@router.get("/check-password")
|
||||
def kb_check_password(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""检查知识库是否需要密码"""
|
||||
stored = _get_kb_password(db)
|
||||
return {"has_password": bool(stored)}
|
||||
|
||||
|
||||
# ========== 公开接口(需登录 + kb_token) ==========
|
||||
|
||||
@router.get("/categories")
|
||||
def get_categories(
|
||||
user: User = Depends(verify_kb_access),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取知识库分类列表"""
|
||||
cats = db.query(KbCategory).filter(KbCategory.is_active == True).order_by(KbCategory.sort_order, KbCategory.id).all()
|
||||
result = []
|
||||
for c in cats:
|
||||
count = db.query(sa_func.count(KbItem.id)).filter(KbItem.category_id == c.id, KbItem.is_active == True).scalar() or 0
|
||||
result.append({"id": c.id, "name": c.name, "icon": c.icon, "count": count})
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/items")
|
||||
def get_items(
|
||||
page: int = Query(1, ge=1),
|
||||
size: int = Query(20, ge=1, le=100),
|
||||
category_id: Optional[int] = None,
|
||||
keyword: Optional[str] = None,
|
||||
user: User = Depends(verify_kb_access),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取知识库条目列表"""
|
||||
query = db.query(KbItem).filter(KbItem.is_active == True)
|
||||
if category_id:
|
||||
query = query.filter(KbItem.category_id == category_id)
|
||||
if keyword:
|
||||
kw = f"%{keyword}%"
|
||||
query = query.filter(or_(KbItem.title.like(kw), KbItem.summary.like(kw)))
|
||||
total = query.count()
|
||||
items = query.order_by(KbItem.sort_order, KbItem.created_at.desc()).offset((page - 1) * size).limit(size).all()
|
||||
|
||||
result = []
|
||||
for item in items:
|
||||
post = db.query(Post).filter(Post.id == item.post_id).first()
|
||||
cat = db.query(KbCategory).filter(KbCategory.id == item.category_id).first() if item.category_id else None
|
||||
author = db.query(User).filter(User.id == item.added_by).first() if item.added_by else None
|
||||
result.append({
|
||||
"id": item.id,
|
||||
"title": item.title,
|
||||
"summary": item.summary or (post.content[:150] if post else ""),
|
||||
"category_id": item.category_id,
|
||||
"category_name": cat.name if cat else "",
|
||||
"post_id": item.post_id,
|
||||
"post_author": post.user_id if post else None,
|
||||
"added_by_name": author.username if author else "",
|
||||
"created_at": item.created_at.isoformat() if item.created_at else None,
|
||||
})
|
||||
|
||||
# 记录访问日志
|
||||
if keyword:
|
||||
db.add(KbAccessLog(user_id=user.id, action="search", query=keyword))
|
||||
db.commit()
|
||||
|
||||
return {"items": result, "total": total, "page": page, "size": size}
|
||||
|
||||
|
||||
@router.get("/items/{item_id}")
|
||||
def get_item_detail(
|
||||
item_id: int,
|
||||
user: User = Depends(verify_kb_access),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取知识库条目详情(含帖子完整内容)"""
|
||||
item = db.query(KbItem).filter(KbItem.id == item_id, KbItem.is_active == True).first()
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="条目不存在")
|
||||
|
||||
post = db.query(Post).filter(Post.id == item.post_id).first()
|
||||
cat = db.query(KbCategory).filter(KbCategory.id == item.category_id).first() if item.category_id else None
|
||||
post_author = db.query(User).filter(User.id == post.user_id).first() if post else None
|
||||
|
||||
# 记录访问
|
||||
db.add(KbAccessLog(user_id=user.id, action="view", query=str(item_id)))
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
"id": item.id,
|
||||
"title": item.title,
|
||||
"summary": item.summary,
|
||||
"category_id": item.category_id,
|
||||
"category_name": cat.name if cat else "",
|
||||
"post_id": item.post_id,
|
||||
"post_title": post.title if post else "",
|
||||
"post_content": post.content if post else "",
|
||||
"post_author": {"id": post_author.id, "username": post_author.username} if post_author else None,
|
||||
"post_tags": post.tags if post else "",
|
||||
"post_category": post.category if post else "",
|
||||
"created_at": item.created_at.isoformat() if item.created_at else None,
|
||||
"attachments": [
|
||||
{"id": a.id, "filename": a.filename, "url": a.url, "file_size": a.file_size, "file_type": a.file_type}
|
||||
for a in db.query(Attachment).filter(Attachment.post_id == item.post_id).order_by(Attachment.created_at.asc()).all()
|
||||
] if post else [],
|
||||
}
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
def get_kb_stats(
|
||||
user: User = Depends(verify_kb_access),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取知识库统计"""
|
||||
total_items = db.query(sa_func.count(KbItem.id)).filter(KbItem.is_active == True).scalar() or 0
|
||||
total_categories = db.query(sa_func.count(KbCategory.id)).filter(KbCategory.is_active == True).scalar() or 0
|
||||
total_views = db.query(sa_func.count(KbAccessLog.id)).filter(KbAccessLog.action == "view").scalar() or 0
|
||||
total_searches = db.query(sa_func.count(KbAccessLog.id)).filter(KbAccessLog.action == "search").scalar() or 0
|
||||
total_ai_chats = db.query(sa_func.count(KbAccessLog.id)).filter(KbAccessLog.action == "ai_chat").scalar() or 0
|
||||
return {
|
||||
"total_items": total_items,
|
||||
"total_categories": total_categories,
|
||||
"total_views": total_views,
|
||||
"total_searches": total_searches,
|
||||
"total_ai_chats": total_ai_chats,
|
||||
}
|
||||
|
||||
|
||||
# ========== AI 智能问答 ==========
|
||||
|
||||
KB_AI_SYSTEM_PROMPT = """你是一个团队知识库的AI助手。你的任务是根据知识库中的内容回答用户的问题。
|
||||
|
||||
规则:
|
||||
1. 只基于提供的知识库内容回答问题,不编造信息
|
||||
2. 如果知识库中没有相关内容,诚实告知用户
|
||||
3. 回答时引用来源文章的标题
|
||||
4. 使用 Markdown 格式组织回答
|
||||
5. 回答要简洁、专业、有条理"""
|
||||
|
||||
|
||||
@router.post("/ai-chat")
|
||||
async def kb_ai_chat(
|
||||
request: KbAiChatRequest,
|
||||
user: User = Depends(verify_kb_access),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""AI 智能问答(SSE 流式)"""
|
||||
question = request.question.strip()
|
||||
if not question:
|
||||
raise HTTPException(status_code=400, detail="请输入问题")
|
||||
|
||||
# 从知识库中检索相关内容作为上下文
|
||||
kw = f"%{question[:50]}%"
|
||||
# 简单关键词匹配(后续可升级为向量搜索)
|
||||
related_items = (
|
||||
db.query(KbItem)
|
||||
.filter(KbItem.is_active == True)
|
||||
.filter(or_(KbItem.title.like(kw), KbItem.summary.like(kw)))
|
||||
.limit(5)
|
||||
.all()
|
||||
)
|
||||
|
||||
# 如果关键词搜索结果不足,补充最新的条目
|
||||
if len(related_items) < 3:
|
||||
existing_ids = [item.id for item in related_items]
|
||||
extra = (
|
||||
db.query(KbItem)
|
||||
.filter(KbItem.is_active == True, ~KbItem.id.in_(existing_ids) if existing_ids else True)
|
||||
.order_by(KbItem.created_at.desc())
|
||||
.limit(10 - len(related_items))
|
||||
.all()
|
||||
)
|
||||
related_items.extend(extra)
|
||||
|
||||
# 构建知识库上下文
|
||||
context_parts = []
|
||||
for item in related_items:
|
||||
post = db.query(Post).filter(Post.id == item.post_id).first()
|
||||
if post:
|
||||
content = post.content[:2000] # 限制每篇长度
|
||||
context_parts.append(f"### {item.title}\n{content}")
|
||||
|
||||
kb_context = "\n\n---\n\n".join(context_parts) if context_parts else "知识库暂无相关内容。"
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": f"以下是知识库中的相关内容:\n\n{kb_context}\n\n---\n\n用户问题:{question}"}
|
||||
]
|
||||
|
||||
# 记录日志
|
||||
db.add(KbAccessLog(user_id=user.id, action="ai_chat", query=question))
|
||||
db.commit()
|
||||
|
||||
# 流式调用AI
|
||||
async def generate():
|
||||
full_response = ""
|
||||
try:
|
||||
result = await ai_service.chat(
|
||||
task_type="reasoning",
|
||||
messages=messages,
|
||||
system_prompt=KB_AI_SYSTEM_PROMPT,
|
||||
stream=True,
|
||||
)
|
||||
if isinstance(result, str):
|
||||
full_response = result
|
||||
yield f"data: {json.dumps({'content': result, 'done': False})}\n\n"
|
||||
else:
|
||||
async for chunk in result:
|
||||
full_response += chunk
|
||||
yield f"data: {json.dumps({'content': chunk, 'done': False})}\n\n"
|
||||
except Exception as e:
|
||||
error_msg = f"AI调用出错: {str(e)}"
|
||||
yield f"data: {json.dumps({'content': error_msg, 'done': False})}\n\n"
|
||||
|
||||
yield f"data: {json.dumps({'content': '', 'done': True})}\n\n"
|
||||
|
||||
return StreamingResponse(generate(), media_type="text/event-stream")
|
||||
|
||||
|
||||
# ========== 管理员接口 ==========
|
||||
|
||||
# --- 密码管理 ---
|
||||
|
||||
@router.put("/admin/password")
|
||||
def set_kb_password(
|
||||
body: PasswordBody,
|
||||
admin: User = Depends(get_admin_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""设置/修改知识库密码"""
|
||||
hashed = _hash_password(body.password) if body.password else ""
|
||||
row = db.query(SystemConfig).filter(SystemConfig.key == "kb_password").first()
|
||||
if row:
|
||||
row.value = hashed
|
||||
else:
|
||||
db.add(SystemConfig(key="kb_password", value=hashed, description="知识库访问密码"))
|
||||
db.commit()
|
||||
return {"message": "密码已更新"}
|
||||
|
||||
|
||||
@router.get("/admin/password-status")
|
||||
def get_kb_password_status(
|
||||
admin: User = Depends(get_admin_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取知识库密码状态"""
|
||||
stored = _get_kb_password(db)
|
||||
return {"has_password": bool(stored)}
|
||||
|
||||
|
||||
# --- 分类管理 ---
|
||||
|
||||
@router.get("/admin/categories")
|
||||
def admin_list_categories(
|
||||
admin: User = Depends(get_admin_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取所有知识库分类(含禁用)"""
|
||||
cats = db.query(KbCategory).order_by(KbCategory.sort_order, KbCategory.id).all()
|
||||
return [
|
||||
{
|
||||
"id": c.id, "name": c.name, "icon": c.icon,
|
||||
"sort_order": c.sort_order, "is_active": c.is_active,
|
||||
"item_count": db.query(sa_func.count(KbItem.id)).filter(KbItem.category_id == c.id).scalar() or 0,
|
||||
"created_at": c.created_at.isoformat() if c.created_at else None,
|
||||
}
|
||||
for c in cats
|
||||
]
|
||||
|
||||
|
||||
@router.post("/admin/categories")
|
||||
def admin_create_category(
|
||||
data: KbCategoryCreate,
|
||||
admin: User = Depends(get_admin_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""创建知识库分类"""
|
||||
exists = db.query(KbCategory).filter(KbCategory.name == data.name).first()
|
||||
if exists:
|
||||
raise HTTPException(status_code=400, detail="分类名称已存在")
|
||||
cat = KbCategory(name=data.name, icon=data.icon)
|
||||
db.add(cat)
|
||||
db.commit()
|
||||
db.refresh(cat)
|
||||
return {"id": cat.id, "name": cat.name, "message": "创建成功"}
|
||||
|
||||
|
||||
@router.put("/admin/categories/{cat_id}")
|
||||
def admin_update_category(
|
||||
cat_id: int,
|
||||
data: KbCategoryUpdate,
|
||||
admin: User = Depends(get_admin_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""更新知识库分类"""
|
||||
cat = db.query(KbCategory).filter(KbCategory.id == cat_id).first()
|
||||
if not cat:
|
||||
raise HTTPException(status_code=404, detail="分类不存在")
|
||||
updates = data.dict(exclude_none=True)
|
||||
for key, value in updates.items():
|
||||
setattr(cat, key, value)
|
||||
db.commit()
|
||||
return {"message": "更新成功"}
|
||||
|
||||
|
||||
@router.delete("/admin/categories/{cat_id}")
|
||||
def admin_delete_category(
|
||||
cat_id: int,
|
||||
admin: User = Depends(get_admin_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""删除知识库分类"""
|
||||
cat = db.query(KbCategory).filter(KbCategory.id == cat_id).first()
|
||||
if not cat:
|
||||
raise HTTPException(status_code=404, detail="分类不存在")
|
||||
# 将该分类下的条目设为未分类
|
||||
db.query(KbItem).filter(KbItem.category_id == cat_id).update({"category_id": None})
|
||||
db.delete(cat)
|
||||
db.commit()
|
||||
return {"message": "删除成功"}
|
||||
|
||||
|
||||
# --- 条目管理 ---
|
||||
|
||||
@router.get("/admin/items")
|
||||
def admin_list_items(
|
||||
page: int = Query(1, ge=1),
|
||||
size: int = Query(20, ge=1, le=100),
|
||||
category_id: Optional[int] = None,
|
||||
keyword: Optional[str] = None,
|
||||
admin: User = Depends(get_admin_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取所有知识库条目"""
|
||||
query = db.query(KbItem)
|
||||
if category_id:
|
||||
query = query.filter(KbItem.category_id == category_id)
|
||||
if keyword:
|
||||
kw = f"%{keyword}%"
|
||||
query = query.filter(or_(KbItem.title.like(kw), KbItem.summary.like(kw)))
|
||||
total = query.count()
|
||||
items = query.order_by(KbItem.created_at.desc()).offset((page - 1) * size).limit(size).all()
|
||||
|
||||
result = []
|
||||
for item in items:
|
||||
post = db.query(Post).filter(Post.id == item.post_id).first()
|
||||
cat = db.query(KbCategory).filter(KbCategory.id == item.category_id).first() if item.category_id else None
|
||||
author = db.query(User).filter(User.id == item.added_by).first() if item.added_by else None
|
||||
post_author = db.query(User).filter(User.id == post.user_id).first() if post else None
|
||||
result.append({
|
||||
"id": item.id,
|
||||
"title": item.title,
|
||||
"summary": item.summary or "",
|
||||
"category_id": item.category_id,
|
||||
"category_name": cat.name if cat else "未分类",
|
||||
"post_id": item.post_id,
|
||||
"post_title": post.title if post else "(已删除)",
|
||||
"post_author_name": post_author.username if post_author else "",
|
||||
"added_by_name": author.username if author else "",
|
||||
"is_active": item.is_active,
|
||||
"sort_order": item.sort_order,
|
||||
"created_at": item.created_at.isoformat() if item.created_at else None,
|
||||
})
|
||||
return {"items": result, "total": total, "page": page, "size": size}
|
||||
|
||||
|
||||
@router.post("/admin/items")
|
||||
def admin_add_items(
|
||||
data: KbItemAdd,
|
||||
admin: User = Depends(get_admin_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""从帖子批量添加到知识库"""
|
||||
added = 0
|
||||
skipped = 0
|
||||
for post_id in data.post_ids:
|
||||
# 检查是否已存在
|
||||
exists = db.query(KbItem).filter(KbItem.post_id == post_id).first()
|
||||
if exists:
|
||||
skipped += 1
|
||||
continue
|
||||
post = db.query(Post).filter(Post.id == post_id).first()
|
||||
if not post:
|
||||
skipped += 1
|
||||
continue
|
||||
item = KbItem(
|
||||
post_id=post_id,
|
||||
category_id=data.category_id,
|
||||
title=post.title,
|
||||
summary=post.content[:200] if post.content else "",
|
||||
added_by=admin.id,
|
||||
)
|
||||
db.add(item)
|
||||
added += 1
|
||||
db.commit()
|
||||
return {"message": f"已添加 {added} 条,跳过 {skipped} 条(已存在或不存在)", "added": added, "skipped": skipped}
|
||||
|
||||
|
||||
@router.put("/admin/items/{item_id}")
|
||||
def admin_update_item(
|
||||
item_id: int,
|
||||
data: KbItemUpdate,
|
||||
admin: User = Depends(get_admin_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""更新知识库条目"""
|
||||
item = db.query(KbItem).filter(KbItem.id == item_id).first()
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="条目不存在")
|
||||
updates = data.dict(exclude_none=True)
|
||||
for key, value in updates.items():
|
||||
setattr(item, key, value)
|
||||
db.commit()
|
||||
return {"message": "更新成功"}
|
||||
|
||||
|
||||
@router.delete("/admin/items/{item_id}")
|
||||
def admin_delete_item(
|
||||
item_id: int,
|
||||
admin: User = Depends(get_admin_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""删除知识库条目"""
|
||||
item = db.query(KbItem).filter(KbItem.id == item_id).first()
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="条目不存在")
|
||||
db.delete(item)
|
||||
db.commit()
|
||||
return {"message": "删除成功"}
|
||||
|
||||
|
||||
@router.get("/admin/posts-for-pick")
|
||||
def admin_posts_for_pick(
|
||||
page: int = Query(1, ge=1),
|
||||
size: int = Query(20, ge=1, le=50),
|
||||
keyword: Optional[str] = None,
|
||||
category: Optional[str] = None,
|
||||
admin: User = Depends(get_admin_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取可选帖子列表(排除已加入知识库的)"""
|
||||
existing_post_ids = [r[0] for r in db.query(KbItem.post_id).all()]
|
||||
query = db.query(Post).filter(Post.is_public == True)
|
||||
if existing_post_ids:
|
||||
query = query.filter(~Post.id.in_(existing_post_ids))
|
||||
if keyword:
|
||||
kw = f"%{keyword}%"
|
||||
query = query.filter(or_(Post.title.like(kw), Post.content.like(kw)))
|
||||
if category:
|
||||
query = query.filter(Post.category == category)
|
||||
total = query.count()
|
||||
posts = query.order_by(Post.created_at.desc()).offset((page - 1) * size).limit(size).all()
|
||||
|
||||
result = []
|
||||
for p in posts:
|
||||
author = db.query(User).filter(User.id == p.user_id).first()
|
||||
result.append({
|
||||
"id": p.id,
|
||||
"title": p.title,
|
||||
"category": p.category,
|
||||
"content_preview": p.content[:100] if p.content else "",
|
||||
"author_name": author.username if author else "",
|
||||
"like_count": p.like_count,
|
||||
"view_count": p.view_count,
|
||||
"created_at": p.created_at.isoformat() if p.created_at else None,
|
||||
})
|
||||
return {"items": result, "total": total, "page": page, "size": size}
|
||||
|
||||
|
||||
# --- 管理员统计 ---
|
||||
|
||||
@router.get("/admin/stats")
|
||||
def admin_kb_stats(
|
||||
admin: User = Depends(get_admin_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""管理员统计数据"""
|
||||
total_items = db.query(sa_func.count(KbItem.id)).scalar() or 0
|
||||
active_items = db.query(sa_func.count(KbItem.id)).filter(KbItem.is_active == True).scalar() or 0
|
||||
total_categories = db.query(sa_func.count(KbCategory.id)).scalar() or 0
|
||||
total_views = db.query(sa_func.count(KbAccessLog.id)).filter(KbAccessLog.action == "view").scalar() or 0
|
||||
total_searches = db.query(sa_func.count(KbAccessLog.id)).filter(KbAccessLog.action == "search").scalar() or 0
|
||||
total_ai_chats = db.query(sa_func.count(KbAccessLog.id)).filter(KbAccessLog.action == "ai_chat").scalar() or 0
|
||||
|
||||
# 最近7天趋势
|
||||
from datetime import date
|
||||
today = date.today()
|
||||
daily_stats = []
|
||||
for i in range(6, -1, -1):
|
||||
d = today - timedelta(days=i)
|
||||
day_start = datetime.combine(d, datetime.min.time())
|
||||
day_end = datetime.combine(d, datetime.max.time())
|
||||
views = db.query(sa_func.count(KbAccessLog.id)).filter(
|
||||
KbAccessLog.action == "view", KbAccessLog.created_at.between(day_start, day_end)
|
||||
).scalar() or 0
|
||||
searches = db.query(sa_func.count(KbAccessLog.id)).filter(
|
||||
KbAccessLog.action == "search", KbAccessLog.created_at.between(day_start, day_end)
|
||||
).scalar() or 0
|
||||
ai_chats = db.query(sa_func.count(KbAccessLog.id)).filter(
|
||||
KbAccessLog.action == "ai_chat", KbAccessLog.created_at.between(day_start, day_end)
|
||||
).scalar() or 0
|
||||
daily_stats.append({"date": d.isoformat(), "views": views, "searches": searches, "ai_chats": ai_chats})
|
||||
|
||||
return {
|
||||
"total_items": total_items,
|
||||
"active_items": active_items,
|
||||
"total_categories": total_categories,
|
||||
"total_views": total_views,
|
||||
"total_searches": total_searches,
|
||||
"total_ai_chats": total_ai_chats,
|
||||
"daily_stats": daily_stats,
|
||||
}
|
||||
360
backend/routers/nav.py
Normal file
360
backend/routers/nav.py
Normal file
@@ -0,0 +1,360 @@
|
||||
"""导航站路由"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func as sa_func
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
|
||||
from database import get_db
|
||||
from models.user import User
|
||||
from models.nav_category import NavCategory
|
||||
from models.nav_link import NavLink
|
||||
from routers.auth import get_current_user, get_admin_user
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ========== Schemas ==========
|
||||
|
||||
class NavCategoryCreate(BaseModel):
|
||||
name: str
|
||||
icon: str = ""
|
||||
|
||||
class NavCategoryUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
icon: Optional[str] = None
|
||||
sort_order: Optional[int] = None
|
||||
is_active: Optional[bool] = None
|
||||
|
||||
class NavLinkCreate(BaseModel):
|
||||
category_id: int
|
||||
name: str
|
||||
url: str
|
||||
icon: str = ""
|
||||
description: str = ""
|
||||
|
||||
class NavLinkUpdate(BaseModel):
|
||||
category_id: Optional[int] = None
|
||||
name: Optional[str] = None
|
||||
url: Optional[str] = None
|
||||
icon: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
sort_order: Optional[int] = None
|
||||
is_active: Optional[bool] = None
|
||||
|
||||
class NavLinkSubmit(BaseModel):
|
||||
"""用户提交导航链接"""
|
||||
category_id: int
|
||||
name: str
|
||||
url: str
|
||||
icon: str = ""
|
||||
description: str = ""
|
||||
|
||||
class NavLinkReview(BaseModel):
|
||||
"""审核操作"""
|
||||
action: str # approve / reject
|
||||
reject_reason: str = ""
|
||||
|
||||
|
||||
# ========== 管理员接口 ==========
|
||||
|
||||
@router.get("/admin/categories")
|
||||
def admin_list_categories(
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user),
|
||||
):
|
||||
"""获取所有导航分类(含禁用)"""
|
||||
cats = db.query(NavCategory).order_by(NavCategory.sort_order, NavCategory.id).all()
|
||||
return [
|
||||
{
|
||||
"id": c.id, "name": c.name, "icon": c.icon,
|
||||
"sort_order": c.sort_order, "is_active": c.is_active,
|
||||
"link_count": db.query(sa_func.count(NavLink.id)).filter(NavLink.category_id == c.id).scalar() or 0,
|
||||
}
|
||||
for c in cats
|
||||
]
|
||||
|
||||
|
||||
@router.post("/admin/categories")
|
||||
def admin_create_category(
|
||||
data: NavCategoryCreate,
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user),
|
||||
):
|
||||
"""新增导航分类"""
|
||||
existing = db.query(NavCategory).filter(NavCategory.name == data.name).first()
|
||||
if existing:
|
||||
raise HTTPException(status_code=400, detail="分类名称已存在")
|
||||
max_order = db.query(sa_func.max(NavCategory.sort_order)).scalar() or 0
|
||||
cat = NavCategory(name=data.name, icon=data.icon, sort_order=max_order + 1)
|
||||
db.add(cat)
|
||||
db.commit()
|
||||
db.refresh(cat)
|
||||
return {"id": cat.id, "name": cat.name, "icon": cat.icon, "sort_order": cat.sort_order, "is_active": cat.is_active}
|
||||
|
||||
|
||||
@router.put("/admin/categories/{cat_id}")
|
||||
def admin_update_category(
|
||||
cat_id: int,
|
||||
data: NavCategoryUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user),
|
||||
):
|
||||
"""编辑导航分类"""
|
||||
cat = db.query(NavCategory).filter(NavCategory.id == cat_id).first()
|
||||
if not cat:
|
||||
raise HTTPException(status_code=404, detail="分类不存在")
|
||||
if data.name is not None:
|
||||
dup = db.query(NavCategory).filter(NavCategory.name == data.name, NavCategory.id != cat_id).first()
|
||||
if dup:
|
||||
raise HTTPException(status_code=400, detail="分类名称已存在")
|
||||
cat.name = data.name
|
||||
if data.icon is not None:
|
||||
cat.icon = data.icon
|
||||
if data.sort_order is not None:
|
||||
cat.sort_order = data.sort_order
|
||||
if data.is_active is not None:
|
||||
cat.is_active = data.is_active
|
||||
db.commit()
|
||||
db.refresh(cat)
|
||||
return {"id": cat.id, "name": cat.name, "icon": cat.icon, "sort_order": cat.sort_order, "is_active": cat.is_active}
|
||||
|
||||
|
||||
@router.delete("/admin/categories/{cat_id}")
|
||||
def admin_delete_category(
|
||||
cat_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user),
|
||||
):
|
||||
"""删除导航分类(级联删除链接)"""
|
||||
cat = db.query(NavCategory).filter(NavCategory.id == cat_id).first()
|
||||
if not cat:
|
||||
raise HTTPException(status_code=404, detail="分类不存在")
|
||||
db.query(NavLink).filter(NavLink.category_id == cat_id).delete()
|
||||
db.delete(cat)
|
||||
db.commit()
|
||||
return {"message": "删除成功"}
|
||||
|
||||
|
||||
@router.get("/admin/links")
|
||||
def admin_list_links(
|
||||
category_id: Optional[int] = None,
|
||||
status: Optional[str] = None,
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user),
|
||||
):
|
||||
"""获取导航链接列表"""
|
||||
query = db.query(NavLink)
|
||||
if category_id is not None:
|
||||
query = query.filter(NavLink.category_id == category_id)
|
||||
if status is not None:
|
||||
query = query.filter(NavLink.status == status)
|
||||
links = query.order_by(NavLink.sort_order, NavLink.id).all()
|
||||
return [
|
||||
{
|
||||
"id": l.id, "category_id": l.category_id, "name": l.name,
|
||||
"url": l.url, "icon": l.icon, "description": l.description,
|
||||
"sort_order": l.sort_order, "is_active": l.is_active,
|
||||
"status": l.status, "submitted_by": l.submitted_by,
|
||||
"reject_reason": l.reject_reason or "",
|
||||
}
|
||||
for l in links
|
||||
]
|
||||
|
||||
|
||||
@router.post("/admin/links")
|
||||
def admin_create_link(
|
||||
data: NavLinkCreate,
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user),
|
||||
):
|
||||
"""新增导航链接"""
|
||||
cat = db.query(NavCategory).filter(NavCategory.id == data.category_id).first()
|
||||
if not cat:
|
||||
raise HTTPException(status_code=400, detail="分类不存在")
|
||||
max_order = db.query(sa_func.max(NavLink.sort_order)).filter(NavLink.category_id == data.category_id).scalar() or 0
|
||||
link = NavLink(
|
||||
category_id=data.category_id, name=data.name, url=data.url,
|
||||
icon=data.icon, description=data.description, sort_order=max_order + 1,
|
||||
)
|
||||
db.add(link)
|
||||
db.commit()
|
||||
db.refresh(link)
|
||||
return {
|
||||
"id": link.id, "category_id": link.category_id, "name": link.name,
|
||||
"url": link.url, "icon": link.icon, "description": link.description,
|
||||
"sort_order": link.sort_order, "is_active": link.is_active,
|
||||
}
|
||||
|
||||
|
||||
@router.put("/admin/links/{link_id}")
|
||||
def admin_update_link(
|
||||
link_id: int,
|
||||
data: NavLinkUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user),
|
||||
):
|
||||
"""编辑导航链接"""
|
||||
link = db.query(NavLink).filter(NavLink.id == link_id).first()
|
||||
if not link:
|
||||
raise HTTPException(status_code=404, detail="链接不存在")
|
||||
if data.category_id is not None:
|
||||
link.category_id = data.category_id
|
||||
if data.name is not None:
|
||||
link.name = data.name
|
||||
if data.url is not None:
|
||||
link.url = data.url
|
||||
if data.icon is not None:
|
||||
link.icon = data.icon
|
||||
if data.description is not None:
|
||||
link.description = data.description
|
||||
if data.sort_order is not None:
|
||||
link.sort_order = data.sort_order
|
||||
if data.is_active is not None:
|
||||
link.is_active = data.is_active
|
||||
db.commit()
|
||||
db.refresh(link)
|
||||
return {
|
||||
"id": link.id, "category_id": link.category_id, "name": link.name,
|
||||
"url": link.url, "icon": link.icon, "description": link.description,
|
||||
"sort_order": link.sort_order, "is_active": link.is_active,
|
||||
}
|
||||
|
||||
|
||||
@router.delete("/admin/links/{link_id}")
|
||||
def admin_delete_link(
|
||||
link_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user),
|
||||
):
|
||||
"""删除导航链接"""
|
||||
link = db.query(NavLink).filter(NavLink.id == link_id).first()
|
||||
if not link:
|
||||
raise HTTPException(status_code=404, detail="链接不存在")
|
||||
db.delete(link)
|
||||
db.commit()
|
||||
return {"message": "删除成功"}
|
||||
|
||||
|
||||
@router.get("/admin/pending-count")
|
||||
def admin_pending_count(
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user),
|
||||
):
|
||||
"""获取待审核数量"""
|
||||
count = db.query(sa_func.count(NavLink.id)).filter(NavLink.status == "pending").scalar() or 0
|
||||
return {"count": count}
|
||||
|
||||
|
||||
@router.put("/admin/links/{link_id}/review")
|
||||
def admin_review_link(
|
||||
link_id: int,
|
||||
data: NavLinkReview,
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user),
|
||||
):
|
||||
"""审核导航链接"""
|
||||
link = db.query(NavLink).filter(NavLink.id == link_id).first()
|
||||
if not link:
|
||||
raise HTTPException(status_code=404, detail="链接不存在")
|
||||
if data.action == "approve":
|
||||
link.status = "approved"
|
||||
link.is_active = True
|
||||
link.reject_reason = ""
|
||||
elif data.action == "reject":
|
||||
link.status = "rejected"
|
||||
link.is_active = False
|
||||
link.reject_reason = data.reject_reason
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="无效操作,请使用 approve 或 reject")
|
||||
db.commit()
|
||||
db.refresh(link)
|
||||
return {
|
||||
"id": link.id, "status": link.status,
|
||||
"is_active": link.is_active, "reject_reason": link.reject_reason or "",
|
||||
}
|
||||
|
||||
|
||||
# ========== 用户提交接口 ==========
|
||||
|
||||
@router.post("/submit")
|
||||
def user_submit_link(
|
||||
data: NavLinkSubmit,
|
||||
db: Session = Depends(get_db),
|
||||
user: User = Depends(get_current_user),
|
||||
):
|
||||
"""用户提交导航网站(需管理员审核)"""
|
||||
cat = db.query(NavCategory).filter(NavCategory.id == data.category_id).first()
|
||||
if not cat:
|
||||
raise HTTPException(status_code=400, detail="分类不存在")
|
||||
existing = db.query(NavLink).filter(NavLink.url == data.url).first()
|
||||
if existing:
|
||||
raise HTTPException(status_code=400, detail="该网站已被提交过")
|
||||
link = NavLink(
|
||||
category_id=data.category_id, name=data.name, url=data.url,
|
||||
icon=data.icon, description=data.description,
|
||||
status="pending", submitted_by=user.id, is_active=False,
|
||||
)
|
||||
db.add(link)
|
||||
db.commit()
|
||||
db.refresh(link)
|
||||
return {
|
||||
"id": link.id, "name": link.name, "url": link.url,
|
||||
"status": link.status, "message": "提交成功,等待管理员审核",
|
||||
}
|
||||
|
||||
|
||||
@router.get("/my-submissions")
|
||||
def user_my_submissions(
|
||||
db: Session = Depends(get_db),
|
||||
user: User = Depends(get_current_user),
|
||||
):
|
||||
"""用户查看自己提交的记录"""
|
||||
links = db.query(NavLink).filter(NavLink.submitted_by == user.id).order_by(NavLink.id.desc()).all()
|
||||
return [
|
||||
{
|
||||
"id": l.id, "name": l.name, "url": l.url, "icon": l.icon,
|
||||
"description": l.description, "status": l.status,
|
||||
"reject_reason": l.reject_reason or "",
|
||||
"created_at": l.created_at.isoformat() if l.created_at else "",
|
||||
}
|
||||
for l in links
|
||||
]
|
||||
|
||||
|
||||
# ========== 公开接口 ==========
|
||||
|
||||
@router.get("/public")
|
||||
def get_public_nav(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""获取所有启用分类及其下启用的链接"""
|
||||
cats = db.query(NavCategory).filter(NavCategory.is_active == True).order_by(NavCategory.sort_order, NavCategory.id).all()
|
||||
result = []
|
||||
for c in cats:
|
||||
links = (
|
||||
db.query(NavLink)
|
||||
.filter(NavLink.category_id == c.id, NavLink.is_active == True, NavLink.status == "approved")
|
||||
.order_by(NavLink.sort_order, NavLink.id)
|
||||
.all()
|
||||
)
|
||||
if links:
|
||||
result.append({
|
||||
"id": c.id, "name": c.name, "icon": c.icon,
|
||||
"links": [
|
||||
{"id": l.id, "name": l.name, "url": l.url, "icon": l.icon, "description": l.description}
|
||||
for l in links
|
||||
],
|
||||
})
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/public/categories")
|
||||
def get_public_categories(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""获取所有启用分类(供用户提交时选择)"""
|
||||
cats = db.query(NavCategory).filter(NavCategory.is_active == True).order_by(NavCategory.sort_order, NavCategory.id).all()
|
||||
return [{"id": c.id, "name": c.name, "icon": c.icon} for c in cats]
|
||||
89
backend/routers/notifications.py
Normal file
89
backend/routers/notifications.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""消息通知路由"""
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func
|
||||
from database import get_db
|
||||
from models.user import User
|
||||
from models.notification import Notification
|
||||
from routers.auth import get_current_user
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("")
|
||||
def get_notifications(
|
||||
page: int = 1,
|
||||
page_size: int = 30,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取通知列表"""
|
||||
notifs = (
|
||||
db.query(Notification)
|
||||
.filter(Notification.user_id == current_user.id)
|
||||
.order_by(Notification.created_at.desc())
|
||||
.offset((page - 1) * page_size).limit(page_size).all()
|
||||
)
|
||||
|
||||
# 获取触发用户信息
|
||||
from_ids = list(set(n.from_user_id for n in notifs if n.from_user_id))
|
||||
users = db.query(User).filter(User.id.in_(from_ids)).all() if from_ids else []
|
||||
user_map = {u.id: u for u in users}
|
||||
|
||||
return [
|
||||
{
|
||||
"id": n.id,
|
||||
"type": n.type,
|
||||
"content": n.content,
|
||||
"related_id": n.related_id,
|
||||
"is_read": n.is_read,
|
||||
"created_at": n.created_at,
|
||||
"from_user": {
|
||||
"id": n.from_user_id,
|
||||
"username": user_map[n.from_user_id].username,
|
||||
"avatar": user_map[n.from_user_id].avatar,
|
||||
} if n.from_user_id and n.from_user_id in user_map else None,
|
||||
}
|
||||
for n in notifs
|
||||
]
|
||||
|
||||
|
||||
@router.get("/unread-count")
|
||||
def get_unread_count(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取未读通知数量"""
|
||||
count = (
|
||||
db.query(func.count(Notification.id))
|
||||
.filter(Notification.user_id == current_user.id, Notification.is_read == False)
|
||||
.scalar()
|
||||
)
|
||||
return {"count": count}
|
||||
|
||||
|
||||
@router.put("/read-all")
|
||||
def read_all(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""全部标为已读"""
|
||||
db.query(Notification).filter(
|
||||
Notification.user_id == current_user.id, Notification.is_read == False
|
||||
).update({"is_read": True})
|
||||
db.commit()
|
||||
return {"message": "已全部标为已读"}
|
||||
|
||||
|
||||
@router.put("/{notif_id}/read")
|
||||
def read_one(
|
||||
notif_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""单条标为已读"""
|
||||
db.query(Notification).filter(
|
||||
Notification.id == notif_id, Notification.user_id == current_user.id
|
||||
).update({"is_read": True})
|
||||
db.commit()
|
||||
return {"message": "已标为已读"}
|
||||
440
backend/routers/posts.py
Normal file
440
backend/routers/posts.py
Normal file
@@ -0,0 +1,440 @@
|
||||
"""经验知识库路由"""
|
||||
import json
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import or_
|
||||
|
||||
from database import get_db
|
||||
from models.user import User
|
||||
from models.post import Post
|
||||
from models.comment import Comment
|
||||
from models.like import Like, Collect
|
||||
from models.follow import Follow
|
||||
from models.notification import Notification
|
||||
from models.attachment import Attachment
|
||||
from schemas.post import (
|
||||
PostCreate, PostUpdate, PostResponse, PostListResponse,
|
||||
CommentCreate, CommentResponse,
|
||||
)
|
||||
from routers.auth import get_current_user
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
import re
|
||||
|
||||
|
||||
def _extract_cover_image(content: str) -> str:
|
||||
"""从Markdown内容中提取第一张图片作为封面"""
|
||||
if not content:
|
||||
return ""
|
||||
match = re.search(r'!\[.*?\]\((.*?)\)', content)
|
||||
if match:
|
||||
return match.group(1)
|
||||
img_match = re.search(r'<img[^>]+src=["\']([^"\']+)["\']', content)
|
||||
if img_match:
|
||||
return img_match.group(1)
|
||||
return ""
|
||||
|
||||
|
||||
def _enrich_post_with_author(post: Post, db: Session) -> dict:
|
||||
"""为帖子附加作者信息(用于信息流)"""
|
||||
author = db.query(User).filter(User.id == post.user_id).first()
|
||||
return {
|
||||
"id": post.id, "title": post.title, "content": post.content[:200],
|
||||
"category": post.category, "tags": post.tags,
|
||||
"cover_image": _extract_cover_image(post.content),
|
||||
"view_count": post.view_count, "like_count": post.like_count,
|
||||
"comment_count": post.comment_count, "collect_count": post.collect_count,
|
||||
"created_at": post.created_at, "updated_at": post.updated_at,
|
||||
"author": {
|
||||
"id": post.user_id,
|
||||
"username": author.username if author else "未知",
|
||||
"avatar": author.avatar if author else "",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@router.get("/feed")
|
||||
def get_feed(
|
||||
page: int = 1, page_size: int = 20,
|
||||
category: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""关注的人的帖子流"""
|
||||
following_ids = [f.following_id for f in db.query(Follow.following_id).filter(Follow.follower_id == current_user.id).all()]
|
||||
if not following_ids:
|
||||
return {"items": [], "total": 0, "page": page, "page_size": page_size}
|
||||
query = db.query(Post).filter(Post.user_id.in_(following_ids), Post.is_public == True, Post.is_draft == False)
|
||||
if category:
|
||||
query = query.filter(Post.category == category)
|
||||
total = query.count()
|
||||
posts = (
|
||||
query.order_by(Post.created_at.desc())
|
||||
.offset((page - 1) * page_size).limit(page_size).all()
|
||||
)
|
||||
return {"items": [_enrich_post_with_author(p, db) for p in posts], "total": total, "page": page, "page_size": page_size}
|
||||
|
||||
|
||||
@router.get("/hot")
|
||||
def get_hot_posts(
|
||||
page: int = 1, page_size: int = 20,
|
||||
category: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""热门帖子(按热度排序)"""
|
||||
query = db.query(Post).filter(Post.is_public == True, Post.is_draft == False)
|
||||
if category:
|
||||
query = query.filter(Post.category == category)
|
||||
total = query.count()
|
||||
posts = (
|
||||
query.order_by((Post.like_count * 3 + Post.comment_count * 2 + Post.view_count).desc())
|
||||
.offset((page - 1) * page_size).limit(page_size).all()
|
||||
)
|
||||
return {"items": [_enrich_post_with_author(p, db) for p in posts], "total": total, "page": page, "page_size": page_size}
|
||||
|
||||
|
||||
@router.get("/latest")
|
||||
def get_latest_posts(
|
||||
page: int = 1, page_size: int = 20,
|
||||
category: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""最新帖子"""
|
||||
query = db.query(Post).filter(Post.is_public == True, Post.is_draft == False)
|
||||
if category:
|
||||
query = query.filter(Post.category == category)
|
||||
total = query.count()
|
||||
posts = (
|
||||
query.order_by(Post.created_at.desc())
|
||||
.offset((page - 1) * page_size).limit(page_size).all()
|
||||
)
|
||||
return {"items": [_enrich_post_with_author(p, db) for p in posts], "total": total, "page": page, "page_size": page_size}
|
||||
|
||||
|
||||
@router.get("/drafts")
|
||||
def get_drafts(
|
||||
page: int = 1, page_size: int = 20,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取当前用户的草稿列表"""
|
||||
query = db.query(Post).filter(Post.user_id == current_user.id, Post.is_draft == True)
|
||||
total = query.count()
|
||||
posts = query.order_by(Post.updated_at.desc()).offset((page - 1) * page_size).limit(page_size).all()
|
||||
return {
|
||||
"items": [_enrich_post(p, db, current_user.id) for p in posts],
|
||||
"total": total, "page": page, "page_size": page_size,
|
||||
}
|
||||
|
||||
|
||||
def _enrich_post(post: Post, db: Session, current_user_id: int = None) -> PostResponse:
|
||||
"""填充帖子额外字段"""
|
||||
author = db.query(User).filter(User.id == post.user_id).first()
|
||||
result = PostResponse.model_validate(post)
|
||||
result.author_name = author.username if author else "未知用户"
|
||||
if current_user_id:
|
||||
result.is_liked = db.query(Like).filter(
|
||||
Like.post_id == post.id, Like.user_id == current_user_id
|
||||
).first() is not None
|
||||
result.is_collected = db.query(Collect).filter(
|
||||
Collect.post_id == post.id, Collect.user_id == current_user_id
|
||||
).first() is not None
|
||||
return result
|
||||
|
||||
|
||||
@router.get("", response_model=PostListResponse)
|
||||
def get_posts(
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=100),
|
||||
category: Optional[str] = None,
|
||||
tag: Optional[str] = None,
|
||||
user_id: Optional[int] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取帖子列表"""
|
||||
query = db.query(Post).filter(
|
||||
or_(Post.is_public == True, Post.user_id == current_user.id),
|
||||
Post.is_draft == False,
|
||||
)
|
||||
if category:
|
||||
query = query.filter(Post.category == category)
|
||||
if tag:
|
||||
query = query.filter(Post.tags.contains(tag))
|
||||
if user_id:
|
||||
query = query.filter(Post.user_id == user_id)
|
||||
|
||||
total = query.count()
|
||||
posts = query.order_by(Post.created_at.desc()).offset((page - 1) * page_size).limit(page_size).all()
|
||||
|
||||
return PostListResponse(
|
||||
items=[_enrich_post(p, db, current_user.id) for p in posts],
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
|
||||
@router.post("", response_model=PostResponse)
|
||||
def create_post(
|
||||
data: PostCreate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""发布经验帖"""
|
||||
post = Post(
|
||||
user_id=current_user.id,
|
||||
title=data.title,
|
||||
content=data.content,
|
||||
category=data.category,
|
||||
tags=json.dumps(data.tags, ensure_ascii=False),
|
||||
is_public=data.is_public,
|
||||
is_draft=data.is_draft,
|
||||
)
|
||||
db.add(post)
|
||||
db.commit()
|
||||
db.refresh(post)
|
||||
return _enrich_post(post, db, current_user.id)
|
||||
|
||||
|
||||
@router.get("/{post_id}", response_model=PostResponse)
|
||||
def get_post(
|
||||
post_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取帖子详情"""
|
||||
post = db.query(Post).filter(Post.id == post_id).first()
|
||||
if not post:
|
||||
raise HTTPException(status_code=404, detail="帖子不存在")
|
||||
if not post.is_public and post.user_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="无权访问")
|
||||
|
||||
# 增加浏览量
|
||||
post.view_count += 1
|
||||
db.commit()
|
||||
db.refresh(post)
|
||||
|
||||
return _enrich_post(post, db, current_user.id)
|
||||
|
||||
|
||||
@router.put("/{post_id}", response_model=PostResponse)
|
||||
def update_post(
|
||||
post_id: int,
|
||||
data: PostUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""编辑帖子"""
|
||||
# 管理员可以编辑任意帖子,普通用户只能编辑自己的
|
||||
if current_user.is_admin:
|
||||
post = db.query(Post).filter(Post.id == post_id).first()
|
||||
else:
|
||||
post = db.query(Post).filter(Post.id == post_id, Post.user_id == current_user.id).first()
|
||||
if not post:
|
||||
raise HTTPException(status_code=404, detail="帖子不存在或无权编辑")
|
||||
|
||||
if data.title is not None:
|
||||
post.title = data.title
|
||||
if data.content is not None:
|
||||
post.content = data.content
|
||||
if data.category is not None:
|
||||
post.category = data.category
|
||||
if data.tags is not None:
|
||||
post.tags = json.dumps(data.tags, ensure_ascii=False)
|
||||
if data.is_public is not None:
|
||||
post.is_public = data.is_public
|
||||
if data.is_draft is not None:
|
||||
post.is_draft = data.is_draft
|
||||
|
||||
db.commit()
|
||||
db.refresh(post)
|
||||
return _enrich_post(post, db, current_user.id)
|
||||
|
||||
|
||||
@router.delete("/{post_id}")
|
||||
def delete_post(
|
||||
post_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""删除帖子"""
|
||||
post = db.query(Post).filter(Post.id == post_id, Post.user_id == current_user.id).first()
|
||||
if not post:
|
||||
raise HTTPException(status_code=404, detail="帖子不存在或无权删除")
|
||||
|
||||
db.query(Comment).filter(Comment.post_id == post_id).delete()
|
||||
db.query(Like).filter(Like.post_id == post_id).delete()
|
||||
db.query(Collect).filter(Collect.post_id == post_id).delete()
|
||||
db.query(Attachment).filter(Attachment.post_id == post_id).delete()
|
||||
db.delete(post)
|
||||
db.commit()
|
||||
return {"message": "删除成功"}
|
||||
|
||||
|
||||
@router.post("/{post_id}/like")
|
||||
def toggle_like(
|
||||
post_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""点赞/取消点赞"""
|
||||
post = db.query(Post).filter(Post.id == post_id).first()
|
||||
if not post:
|
||||
raise HTTPException(status_code=404, detail="帖子不存在")
|
||||
|
||||
existing = db.query(Like).filter(
|
||||
Like.post_id == post_id, Like.user_id == current_user.id
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
db.delete(existing)
|
||||
post.like_count = max(0, post.like_count - 1)
|
||||
db.commit()
|
||||
return {"liked": False, "like_count": post.like_count}
|
||||
else:
|
||||
db.add(Like(post_id=post_id, user_id=current_user.id))
|
||||
post.like_count += 1
|
||||
# 通知帖子作者
|
||||
if post.user_id != current_user.id:
|
||||
db.add(Notification(
|
||||
user_id=post.user_id, type="like",
|
||||
content=f"{current_user.username} 赞了你的文章「{post.title[:30]}」",
|
||||
from_user_id=current_user.id, related_id=post_id,
|
||||
))
|
||||
db.commit()
|
||||
return {"liked": True, "like_count": post.like_count}
|
||||
|
||||
|
||||
@router.post("/{post_id}/collect")
|
||||
def toggle_collect(
|
||||
post_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""收藏/取消收藏"""
|
||||
post = db.query(Post).filter(Post.id == post_id).first()
|
||||
if not post:
|
||||
raise HTTPException(status_code=404, detail="帖子不存在")
|
||||
|
||||
existing = db.query(Collect).filter(
|
||||
Collect.post_id == post_id, Collect.user_id == current_user.id
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
db.delete(existing)
|
||||
post.collect_count = max(0, post.collect_count - 1)
|
||||
db.commit()
|
||||
return {"collected": False, "collect_count": post.collect_count}
|
||||
else:
|
||||
db.add(Collect(post_id=post_id, user_id=current_user.id))
|
||||
post.collect_count += 1
|
||||
db.commit()
|
||||
return {"collected": True, "collect_count": post.collect_count}
|
||||
|
||||
|
||||
@router.get("/{post_id}/comments", response_model=List[CommentResponse])
|
||||
def get_comments(
|
||||
post_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取评论列表"""
|
||||
comments = (
|
||||
db.query(Comment)
|
||||
.filter(Comment.post_id == post_id)
|
||||
.order_by(Comment.created_at.asc())
|
||||
.all()
|
||||
)
|
||||
results = []
|
||||
for c in comments:
|
||||
author = db.query(User).filter(User.id == c.user_id).first()
|
||||
r = CommentResponse.model_validate(c)
|
||||
r.author_name = author.username if author else "未知用户"
|
||||
results.append(r)
|
||||
return results
|
||||
|
||||
|
||||
@router.post("/{post_id}/comments", response_model=CommentResponse)
|
||||
def create_comment(
|
||||
post_id: int,
|
||||
data: CommentCreate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""发表评论"""
|
||||
post = db.query(Post).filter(Post.id == post_id).first()
|
||||
if not post:
|
||||
raise HTTPException(status_code=404, detail="帖子不存在")
|
||||
|
||||
comment = Comment(
|
||||
post_id=post_id,
|
||||
user_id=current_user.id,
|
||||
content=data.content,
|
||||
)
|
||||
db.add(comment)
|
||||
post.comment_count += 1
|
||||
# 通知帖子作者
|
||||
if post.user_id != current_user.id:
|
||||
db.add(Notification(
|
||||
user_id=post.user_id, type="comment",
|
||||
content=f"{current_user.username} 评论了你的文章「{post.title[:30]}」",
|
||||
from_user_id=current_user.id, related_id=post_id,
|
||||
))
|
||||
db.commit()
|
||||
db.refresh(comment)
|
||||
|
||||
result = CommentResponse.model_validate(comment)
|
||||
result.author_name = current_user.username
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/{post_id}/attachments")
|
||||
def get_attachments(
|
||||
post_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取帖子的附件列表"""
|
||||
attachments = (
|
||||
db.query(Attachment)
|
||||
.filter(Attachment.post_id == post_id)
|
||||
.order_by(Attachment.created_at.asc())
|
||||
.all()
|
||||
)
|
||||
return [
|
||||
{
|
||||
"id": a.id,
|
||||
"filename": a.filename,
|
||||
"url": a.url,
|
||||
"file_size": a.file_size,
|
||||
"file_type": a.file_type,
|
||||
"created_at": a.created_at,
|
||||
}
|
||||
for a in attachments
|
||||
]
|
||||
|
||||
|
||||
@router.delete("/{post_id}/attachments/{attachment_id}")
|
||||
def delete_attachment(
|
||||
post_id: int,
|
||||
attachment_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""删除附件(仅作者)"""
|
||||
attachment = db.query(Attachment).filter(
|
||||
Attachment.id == attachment_id,
|
||||
Attachment.post_id == post_id,
|
||||
Attachment.user_id == current_user.id,
|
||||
).first()
|
||||
if not attachment:
|
||||
raise HTTPException(status_code=404, detail="附件不存在或无权删除")
|
||||
db.delete(attachment)
|
||||
db.commit()
|
||||
return {"message": "删除成功"}
|
||||
410
backend/routers/projects.py
Normal file
410
backend/routers/projects.py
Normal file
@@ -0,0 +1,410 @@
|
||||
"""开源项目路由"""
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func as sa_func, or_
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List
|
||||
|
||||
from database import get_db
|
||||
from models.user import User
|
||||
from models.project import Project
|
||||
from models.like import ProjectCollect
|
||||
from routers.auth import get_current_user, get_admin_user
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
GITHUB_API = "https://api.github.com"
|
||||
|
||||
|
||||
# ========== Schemas ==========
|
||||
|
||||
class ProjectCreate(BaseModel):
|
||||
name: str
|
||||
description: str = ""
|
||||
url: str
|
||||
homepage: str = ""
|
||||
icon: str = ""
|
||||
language: str = ""
|
||||
category: str = ""
|
||||
stars: int = 0
|
||||
forks: int = 0
|
||||
|
||||
|
||||
class ProjectUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
url: Optional[str] = None
|
||||
homepage: Optional[str] = None
|
||||
icon: Optional[str] = None
|
||||
language: Optional[str] = None
|
||||
category: Optional[str] = None
|
||||
stars: Optional[int] = None
|
||||
forks: Optional[int] = None
|
||||
sort_order: Optional[int] = None
|
||||
is_active: Optional[bool] = None
|
||||
|
||||
|
||||
def _project_to_dict(p: Project, is_collected: bool = False) -> dict:
|
||||
return {
|
||||
"id": p.id,
|
||||
"name": p.name,
|
||||
"description": p.description or "",
|
||||
"url": p.url,
|
||||
"homepage": p.homepage or "",
|
||||
"icon": p.icon or "",
|
||||
"language": p.language or "",
|
||||
"category": p.category or "",
|
||||
"stars": p.stars or 0,
|
||||
"forks": p.forks or 0,
|
||||
"collect_count": getattr(p, 'collect_count', 0) or 0,
|
||||
"is_collected": is_collected,
|
||||
"sort_order": p.sort_order,
|
||||
"is_active": p.is_active,
|
||||
"created_at": p.created_at.isoformat() if p.created_at else None,
|
||||
"updated_at": p.updated_at.isoformat() if p.updated_at else None,
|
||||
}
|
||||
|
||||
|
||||
def _with_collect_status(items: list, user_id: int, db: Session) -> list:
|
||||
"""批量查询用户是否已收藏"""
|
||||
if not items:
|
||||
return []
|
||||
project_ids = [p.id for p in items]
|
||||
collected_ids = set(
|
||||
r[0] for r in db.query(ProjectCollect.project_id)
|
||||
.filter(ProjectCollect.project_id.in_(project_ids), ProjectCollect.user_id == user_id)
|
||||
.all()
|
||||
)
|
||||
return [_project_to_dict(p, p.id in collected_ids) for p in items]
|
||||
|
||||
|
||||
# ========== 管理员接口 ==========
|
||||
|
||||
@router.get("/admin/list")
|
||||
def admin_list_projects(
|
||||
page: int = Query(1, ge=1),
|
||||
size: int = Query(20, ge=1, le=100),
|
||||
category: Optional[str] = None,
|
||||
keyword: Optional[str] = None,
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user),
|
||||
):
|
||||
"""获取所有项目(含禁用),支持分页"""
|
||||
query = db.query(Project)
|
||||
if category:
|
||||
query = query.filter(Project.category == category)
|
||||
if keyword:
|
||||
kw = f"%{keyword}%"
|
||||
query = query.filter(or_(Project.name.like(kw), Project.description.like(kw)))
|
||||
total = query.count()
|
||||
items = query.order_by(Project.sort_order, Project.id.desc()).offset((page - 1) * size).limit(size).all()
|
||||
return {
|
||||
"total": total,
|
||||
"items": [_project_to_dict(p) for p in items],
|
||||
}
|
||||
|
||||
|
||||
@router.post("/admin")
|
||||
def admin_create_project(
|
||||
data: ProjectCreate,
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user),
|
||||
):
|
||||
"""新增项目"""
|
||||
max_order = db.query(sa_func.max(Project.sort_order)).scalar() or 0
|
||||
proj = Project(
|
||||
name=data.name,
|
||||
description=data.description,
|
||||
url=data.url,
|
||||
homepage=data.homepage,
|
||||
icon=data.icon,
|
||||
language=data.language,
|
||||
category=data.category,
|
||||
stars=data.stars,
|
||||
forks=data.forks,
|
||||
sort_order=max_order + 1,
|
||||
)
|
||||
db.add(proj)
|
||||
db.commit()
|
||||
db.refresh(proj)
|
||||
return _project_to_dict(proj)
|
||||
|
||||
|
||||
@router.put("/admin/{project_id}")
|
||||
def admin_update_project(
|
||||
project_id: int,
|
||||
data: ProjectUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user),
|
||||
):
|
||||
"""编辑项目"""
|
||||
proj = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not proj:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
for field in ["name", "description", "url", "homepage", "icon", "language", "category", "stars", "forks", "sort_order", "is_active"]:
|
||||
val = getattr(data, field)
|
||||
if val is not None:
|
||||
setattr(proj, field, val)
|
||||
db.commit()
|
||||
db.refresh(proj)
|
||||
return _project_to_dict(proj)
|
||||
|
||||
|
||||
@router.delete("/admin/{project_id}")
|
||||
def admin_delete_project(
|
||||
project_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user),
|
||||
):
|
||||
"""删除项目"""
|
||||
proj = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not proj:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
db.delete(proj)
|
||||
db.commit()
|
||||
return {"message": "删除成功"}
|
||||
|
||||
|
||||
# ========== 公开接口 ==========
|
||||
|
||||
@router.get("/hot")
|
||||
def get_hot_projects(
|
||||
page: int = Query(1, ge=1),
|
||||
size: int = Query(12, ge=1, le=50),
|
||||
category: Optional[str] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""热门项目(按 stars 降序)"""
|
||||
query = db.query(Project).filter(Project.is_active == True)
|
||||
if category:
|
||||
query = query.filter(Project.category == category)
|
||||
total = query.count()
|
||||
items = query.order_by(Project.stars.desc(), Project.sort_order, Project.id.desc()).offset((page - 1) * size).limit(size).all()
|
||||
return {"total": total, "items": _with_collect_status(items, current_user.id, db)}
|
||||
|
||||
|
||||
@router.get("/latest")
|
||||
def get_latest_projects(
|
||||
page: int = Query(1, ge=1),
|
||||
size: int = Query(12, ge=1, le=50),
|
||||
category: Optional[str] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""最新项目(按创建时间降序)"""
|
||||
query = db.query(Project).filter(Project.is_active == True)
|
||||
if category:
|
||||
query = query.filter(Project.category == category)
|
||||
total = query.count()
|
||||
items = query.order_by(Project.created_at.desc(), Project.id.desc()).offset((page - 1) * size).limit(size).all()
|
||||
return {"total": total, "items": _with_collect_status(items, current_user.id, db)}
|
||||
|
||||
|
||||
@router.get("/search")
|
||||
def search_projects(
|
||||
q: str = Query("", min_length=0),
|
||||
page: int = Query(1, ge=1),
|
||||
size: int = Query(12, ge=1, le=50),
|
||||
category: Optional[str] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""搜索项目"""
|
||||
query = db.query(Project).filter(Project.is_active == True)
|
||||
if q.strip():
|
||||
kw = f"%{q.strip()}%"
|
||||
query = query.filter(or_(Project.name.like(kw), Project.description.like(kw)))
|
||||
if category:
|
||||
query = query.filter(Project.category == category)
|
||||
total = query.count()
|
||||
items = query.order_by(Project.stars.desc(), Project.id.desc()).offset((page - 1) * size).limit(size).all()
|
||||
return {"total": total, "items": _with_collect_status(items, current_user.id, db)}
|
||||
|
||||
|
||||
@router.get("/categories")
|
||||
def get_project_categories(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""获取所有有项目的分类"""
|
||||
rows = (
|
||||
db.query(Project.category, sa_func.count(Project.id))
|
||||
.filter(Project.is_active == True, Project.category != "")
|
||||
.group_by(Project.category)
|
||||
.order_by(sa_func.count(Project.id).desc())
|
||||
.all()
|
||||
)
|
||||
return [{"name": r[0], "count": r[1]} for r in rows]
|
||||
|
||||
|
||||
# ========== 收藏接口 ==========
|
||||
|
||||
@router.get("/my-collects")
|
||||
def get_my_collects(
|
||||
page: int = Query(1, ge=1),
|
||||
size: int = Query(12, ge=1, le=50),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""获取当前用户收藏的项目"""
|
||||
subq = db.query(ProjectCollect.project_id).filter(ProjectCollect.user_id == current_user.id).subquery()
|
||||
query = db.query(Project).filter(Project.id.in_(subq), Project.is_active == True)
|
||||
total = query.count()
|
||||
items = query.order_by(Project.id.desc()).offset((page - 1) * size).limit(size).all()
|
||||
return {"total": total, "items": [_project_to_dict(p, True) for p in items]}
|
||||
|
||||
|
||||
@router.post("/{project_id}/collect")
|
||||
def toggle_project_collect(
|
||||
project_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""收藏/取消收藏项目"""
|
||||
proj = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not proj:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
existing = db.query(ProjectCollect).filter(
|
||||
ProjectCollect.project_id == project_id, ProjectCollect.user_id == current_user.id
|
||||
).first()
|
||||
if existing:
|
||||
db.delete(existing)
|
||||
proj.collect_count = max(0, (proj.collect_count or 0) - 1)
|
||||
db.commit()
|
||||
return {"collected": False, "collect_count": proj.collect_count}
|
||||
else:
|
||||
db.add(ProjectCollect(project_id=project_id, user_id=current_user.id))
|
||||
proj.collect_count = (proj.collect_count or 0) + 1
|
||||
db.commit()
|
||||
return {"collected": True, "collect_count": proj.collect_count}
|
||||
|
||||
|
||||
# ========== GitHub 搜索(公共 + 管理员通用) ==========
|
||||
|
||||
async def _github_search_impl(q: str, sort: str, page: int, per_page: int):
|
||||
"""GitHub 搜索核心实现"""
|
||||
url = f"{GITHUB_API}/search/repositories"
|
||||
params = {"q": q, "sort": sort, "order": "desc", "page": page, "per_page": per_page}
|
||||
headers = {"Accept": "application/vnd.github+json"}
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=15) as client:
|
||||
resp = await client.get(url, params=params, headers=headers)
|
||||
if resp.status_code != 200:
|
||||
raise HTTPException(status_code=502, detail=f"GitHub API 返回 {resp.status_code}")
|
||||
data = resp.json()
|
||||
except httpx.TimeoutException:
|
||||
raise HTTPException(status_code=504, detail="GitHub API 请求超时")
|
||||
except httpx.RequestError as e:
|
||||
raise HTTPException(status_code=502, detail=f"网络请求失败: {str(e)}")
|
||||
|
||||
items = []
|
||||
for repo in data.get("items", []):
|
||||
items.append({
|
||||
"github_id": repo["id"],
|
||||
"name": repo.get("name", ""),
|
||||
"full_name": repo.get("full_name", ""),
|
||||
"description": repo.get("description") or "",
|
||||
"url": repo.get("html_url", ""),
|
||||
"homepage": repo.get("homepage") or "",
|
||||
"icon": repo.get("owner", {}).get("avatar_url", ""),
|
||||
"language": repo.get("language") or "",
|
||||
"stars": repo.get("stargazers_count", 0),
|
||||
"forks": repo.get("forks_count", 0),
|
||||
"topics": repo.get("topics", []),
|
||||
"created_at": repo.get("created_at", ""),
|
||||
"updated_at": repo.get("updated_at", ""),
|
||||
})
|
||||
return {"total": data.get("total_count", 0), "items": items}
|
||||
|
||||
|
||||
@router.get("/github-search")
|
||||
async def public_github_search(
|
||||
q: str = Query(..., min_length=1),
|
||||
sort: str = Query("stars"),
|
||||
page: int = Query(1, ge=1),
|
||||
per_page: int = Query(12, ge=1, le=30),
|
||||
user: User = Depends(get_current_user),
|
||||
):
|
||||
"""公开 GitHub 搜索(登录用户可用)"""
|
||||
return await _github_search_impl(q, sort, page, per_page)
|
||||
|
||||
|
||||
# ========== GitHub 导入接口(管理员) ==========
|
||||
|
||||
@router.get("/admin/github-search")
|
||||
async def github_search(
|
||||
q: str = Query(..., min_length=1),
|
||||
sort: str = Query("stars"),
|
||||
page: int = Query(1, ge=1),
|
||||
per_page: int = Query(12, ge=1, le=30),
|
||||
admin: User = Depends(get_admin_user),
|
||||
):
|
||||
"""管理员 GitHub 搜索"""
|
||||
return await _github_search_impl(q, sort, page, per_page)
|
||||
|
||||
|
||||
class GitHubImportItem(BaseModel):
|
||||
name: str
|
||||
description: str = ""
|
||||
url: str
|
||||
homepage: str = ""
|
||||
icon: str = ""
|
||||
language: str = ""
|
||||
category: str = ""
|
||||
stars: int = 0
|
||||
forks: int = 0
|
||||
|
||||
|
||||
class GitHubImportRequest(BaseModel):
|
||||
items: List[GitHubImportItem]
|
||||
|
||||
|
||||
@router.post("/admin/github-import")
|
||||
def github_import(
|
||||
data: GitHubImportRequest,
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user),
|
||||
):
|
||||
"""批量导入 GitHub 项目"""
|
||||
imported = 0
|
||||
skipped = 0
|
||||
max_order = db.query(sa_func.max(Project.sort_order)).scalar() or 0
|
||||
for item in data.items:
|
||||
existing = db.query(Project).filter(Project.url == item.url).first()
|
||||
if existing:
|
||||
skipped += 1
|
||||
continue
|
||||
max_order += 1
|
||||
proj = Project(
|
||||
name=item.name,
|
||||
description=item.description,
|
||||
url=item.url,
|
||||
homepage=item.homepage,
|
||||
icon=item.icon,
|
||||
language=item.language,
|
||||
category=item.category,
|
||||
stars=item.stars,
|
||||
forks=item.forks,
|
||||
sort_order=max_order,
|
||||
)
|
||||
db.add(proj)
|
||||
imported += 1
|
||||
db.commit()
|
||||
return {"imported": imported, "skipped": skipped}
|
||||
|
||||
|
||||
# ========== 项目详情(通配路由放最后) ==========
|
||||
|
||||
@router.get("/{project_id}")
|
||||
def get_project_detail(
|
||||
project_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""项目详情"""
|
||||
proj = db.query(Project).filter(Project.id == project_id, Project.is_active == True).first()
|
||||
if not proj:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
return _project_to_dict(proj)
|
||||
313
backend/routers/requirement.py
Normal file
313
backend/routers/requirement.py
Normal file
@@ -0,0 +1,313 @@
|
||||
"""需求理解助手路由"""
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from typing import List
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models.user import User
|
||||
from models.conversation import Conversation, Message
|
||||
from schemas.conversation import (
|
||||
RequirementAnalyzeRequest, ConversationResponse,
|
||||
ConversationDetail, MessageResponse,
|
||||
)
|
||||
from routers.auth import get_current_user
|
||||
from services.ai_service import ai_service
|
||||
from config import UPLOAD_DIR
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
REQUIREMENT_SYSTEM_PROMPT = """# 角色定义
|
||||
你同时具备两个身份:
|
||||
1. **资深产品经理(10年+)**:擅长从模糊信息中提炼需求本质,精通用户故事、MECE拆解、优先级排序、验收标准制定
|
||||
2. **高级全栈程序员(10年+)**:做过大量项目,对功能的复杂度、开发工作量有精准直觉,能一眼看出哪些需求是"看似简单实则巨坑"
|
||||
|
||||
你服务的对象是**程序员**,他们会把甲方发来的原始内容(口语化描述、语音转文字、截图、聊天记录、需求文档等)发给你。你需要站在「既懂业务又懂技术」的双重视角,将其转化为**清晰明确、可直接进入开发的结构化需求**。
|
||||
|
||||
> ⚠️ 本助手专注于**需求理解与分析**。技术选型、数据库设计、API设计、架构方案等请移步「架构选型AI助手」。
|
||||
|
||||
# 核心理念
|
||||
- **需求层面**:甲方说的不一定是他真正想要的,你要透过表面描述挖掘真实诉求
|
||||
- **程序员直觉**:每个功能你都要心里过一遍复杂度,标注哪些"看起来简单但实际很复杂"
|
||||
- **落地导向**:不出空中楼阁式的分析,每条功能都要能对应到具体的开发任务
|
||||
|
||||
# 分析框架
|
||||
收到用户输入后,按以下框架进行系统性分析:
|
||||
|
||||
## 第一步:需求还原(产品经理视角)
|
||||
- 理解甲方的**核心意图**:到底想解决什么问题?服务什么业务场景?
|
||||
- 识别**目标用户**是谁,使用频率、核心使用路径是什么
|
||||
- 区分"真实需求"与"表面描述",找出甲方没说但一定需要的**隐性需求**
|
||||
- 判断产品定位:工具型/平台型/内容型?ToB/ToC?
|
||||
|
||||
## 第二步:功能拆解(遵循 MECE 原则)
|
||||
将需求拆解为相互独立、完全穷尽的功能模块,每个功能需包含:
|
||||
- 功能名称 + 具体说明
|
||||
- 优先级(P0 核心必做 / P1 重要 / P2 锦上添花)
|
||||
- 涉及的用户角色
|
||||
- 复杂度预判(简单/中等/复杂)+ 复杂原因说明
|
||||
|
||||
## 第三步:用户故事与验收标准
|
||||
将核心功能转写为标准用户故事:
|
||||
> 作为【角色】,我希望【功能】,以便【价值/目的】
|
||||
|
||||
为 P0 功能补充验收标准(AC),采用 Given-When-Then 格式:
|
||||
> 假设【前置条件】,当【用户操作】,那么【系统响应】
|
||||
|
||||
## 第四步:复杂度预警(程序员视角)
|
||||
基于你做过大量项目的经验,标注:
|
||||
- **隐藏复杂度**:哪些功能看似简单但实现起来有坑(如"支持实时聊天"看似一句话,实际涉及WebSocket、消息队列、已读回执等)
|
||||
- **高风险功能**:涉及支付、权限、数据一致性等需要特别慎重的功能
|
||||
- **工期杀手**:容易严重超出预期工时的功能,提前预警
|
||||
- **工期粗估**:按模块给出大致工时范围(x-x天),帮助程序员评估排期
|
||||
|
||||
## 第五步:边界与待确认项
|
||||
- 需求中**含糊不清**需要甲方确认的关键问题
|
||||
- 容易遗漏的**边缘场景**(空状态、异常流、权限边界、数据量极值、并发操作)
|
||||
- 甲方可能还没想到但**一定会追加**的需求(基于经验预判)
|
||||
|
||||
# 输出规范
|
||||
严格使用以下 Markdown 结构输出:
|
||||
|
||||
---
|
||||
|
||||
## 📋 需求概述
|
||||
> 用2-3句话概括:这是什么产品/功能,解决谁的什么问题,核心价值是什么。
|
||||
> 产品定位:【工具型/平台型/内容型】 | 【ToB/ToC】
|
||||
|
||||
## 👥 用户角色
|
||||
| 角色 | 说明 | 关键诉求 | 使用频率 |
|
||||
|------|------|---------|---------|
|
||||
|
||||
## 🧩 功能清单
|
||||
| 优先级 | 模块 | 功能项 | 功能说明 | 涉及角色 | 复杂度 |
|
||||
|--------|------|--------|---------|---------|--------|
|
||||
| P0 | xxx | xxx | xxx | xxx | 🔴复杂 - 原因 |
|
||||
| P0 | xxx | xxx | xxx | xxx | 🟢简单 |
|
||||
| P1 | xxx | xxx | xxx | xxx | 🟡中等 |
|
||||
|
||||
## 📖 核心用户故事与验收标准
|
||||
### US-1: 【用户故事标题】
|
||||
- **故事**:作为【角色】,我希望【功能】,以便【价值】
|
||||
- **验收标准**:
|
||||
- ✅ 假设【条件】,当【操作】,那么【预期结果】
|
||||
- ✅ 假设【条件】,当【异常操作】,那么【兜底处理】
|
||||
|
||||
## ⚡ 复杂度预警(程序员必看)
|
||||
### 🔴 隐藏复杂度
|
||||
1. **【功能名】**:看似xxx,实际需要xxx,建议xxx
|
||||
|
||||
### ⏱️ 工期粗估
|
||||
| 模块 | 工时范围 | 说明 |
|
||||
|------|---------|------|
|
||||
| 合计 | x-x天 | 基于1名全栈开发者,含开发+自测 |
|
||||
|
||||
### 🔮 甲方大概率会追加的需求
|
||||
1. 【需求名】— 理由:基于经验,做了xxx之后甲方通常会要求xxx
|
||||
|
||||
## ❓ 待确认问题
|
||||
> 以下问题会直接影响开发方案,建议优先与甲方确认:
|
||||
1. **【问题】**:【为什么需要确认】→ 不确认的影响:【xxx】
|
||||
|
||||
## 💡 风险提示
|
||||
- 【业务风险、边缘场景、容易遗漏的点等】
|
||||
|
||||
---
|
||||
|
||||
> 💡 **下一步建议**:需求确认后,可将功能清单发送至「架构选型AI助手」,获取技术选型、数据库设计、API接口设计等技术方案。
|
||||
|
||||
# 交互原则
|
||||
1. **首次分析要全面**:即使信息不完整,也要基于已有信息给出最完整的分析,用待确认问题标注不确定的部分
|
||||
2. **复杂度标注要诚实**:程序员最怕"这个很简单",要如实标注复杂度和潜在的坑
|
||||
3. **追问要有价值**:只追问影响开发方案的关键问题,不问显而易见的
|
||||
4. **语言贴近程序员**:用开发者能直接理解的术语,避免纯商业话术
|
||||
5. **持续迭代**:用户补充信息后,在之前分析的基础上更新完善,而非重头开始
|
||||
6. **务实不空谈**:每条建议都基于实际经验,不说"建议优化用户体验"这类空话
|
||||
7. **工时要诚实**:给出合理区间,标注不确定因素"""
|
||||
|
||||
|
||||
@router.post("/upload-image")
|
||||
async def upload_image(
|
||||
file: UploadFile = File(...),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""上传图片"""
|
||||
# 验证文件类型
|
||||
allowed_types = ["image/jpeg", "image/png", "image/gif", "image/webp"]
|
||||
if file.content_type not in allowed_types:
|
||||
raise HTTPException(status_code=400, detail="不支持的文件类型")
|
||||
|
||||
# 生成唯一文件名
|
||||
ext = file.filename.split(".")[-1] if "." in file.filename else "png"
|
||||
filename = f"{uuid.uuid4().hex}.{ext}"
|
||||
filepath = os.path.join(UPLOAD_DIR, filename)
|
||||
|
||||
# 保存文件
|
||||
content = await file.read()
|
||||
with open(filepath, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
return {"url": f"/uploads/{filename}"}
|
||||
|
||||
|
||||
@router.get("/conversations", response_model=List[ConversationResponse])
|
||||
def get_conversations(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取需求对话列表"""
|
||||
conversations = (
|
||||
db.query(Conversation)
|
||||
.filter(Conversation.user_id == current_user.id, Conversation.type == "requirement")
|
||||
.order_by(Conversation.updated_at.desc())
|
||||
.all()
|
||||
)
|
||||
return [ConversationResponse.model_validate(c) for c in conversations]
|
||||
|
||||
|
||||
@router.get("/conversations/{conversation_id}", response_model=ConversationDetail)
|
||||
def get_conversation_detail(
|
||||
conversation_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取对话详情"""
|
||||
conv = db.query(Conversation).filter(
|
||||
Conversation.id == conversation_id,
|
||||
Conversation.user_id == current_user.id,
|
||||
).first()
|
||||
if not conv:
|
||||
raise HTTPException(status_code=404, detail="对话不存在")
|
||||
|
||||
messages = (
|
||||
db.query(Message)
|
||||
.filter(Message.conversation_id == conversation_id)
|
||||
.order_by(Message.created_at.asc())
|
||||
.all()
|
||||
)
|
||||
result = ConversationDetail.model_validate(conv)
|
||||
result.messages = [MessageResponse.model_validate(m) for m in messages]
|
||||
return result
|
||||
|
||||
|
||||
@router.post("/analyze")
|
||||
async def analyze_requirement(
|
||||
request: RequirementAnalyzeRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""分析需求 - 流式输出"""
|
||||
# 创建或获取对话
|
||||
if request.conversation_id:
|
||||
conv = db.query(Conversation).filter(
|
||||
Conversation.id == request.conversation_id,
|
||||
Conversation.user_id == current_user.id,
|
||||
).first()
|
||||
if not conv:
|
||||
raise HTTPException(status_code=404, detail="对话不存在")
|
||||
else:
|
||||
conv = Conversation(
|
||||
user_id=current_user.id,
|
||||
title=request.content[:50] if request.content else "新需求分析",
|
||||
type="requirement",
|
||||
)
|
||||
db.add(conv)
|
||||
db.commit()
|
||||
db.refresh(conv)
|
||||
|
||||
# 保存用户消息
|
||||
user_msg = Message(
|
||||
conversation_id=conv.id,
|
||||
role="user",
|
||||
content=request.content,
|
||||
image_urls=json.dumps(request.image_urls) if request.image_urls else "",
|
||||
)
|
||||
db.add(user_msg)
|
||||
db.commit()
|
||||
|
||||
# 构建历史消息
|
||||
history_msgs = (
|
||||
db.query(Message)
|
||||
.filter(Message.conversation_id == conv.id)
|
||||
.order_by(Message.created_at.asc())
|
||||
.all()
|
||||
)
|
||||
|
||||
messages = []
|
||||
for msg in history_msgs:
|
||||
if msg.role == "user":
|
||||
content = msg.content
|
||||
# 如果有图片,添加图片描述提示
|
||||
if msg.image_urls:
|
||||
try:
|
||||
urls = json.loads(msg.image_urls)
|
||||
if urls:
|
||||
content += f"\n\n[用户上传了{len(urls)}张图片]"
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
messages.append({"role": "user", "content": content})
|
||||
else:
|
||||
messages.append({"role": "assistant", "content": msg.content})
|
||||
|
||||
# 确定任务类型
|
||||
task_type = "multimodal" if request.image_urls else "reasoning"
|
||||
|
||||
# 流式调用AI
|
||||
async def generate():
|
||||
full_response = ""
|
||||
try:
|
||||
result = await ai_service.chat(
|
||||
task_type=task_type,
|
||||
messages=messages,
|
||||
system_prompt=REQUIREMENT_SYSTEM_PROMPT,
|
||||
stream=True,
|
||||
model_config_id=request.model_config_id,
|
||||
)
|
||||
if isinstance(result, str):
|
||||
# 非流式返回
|
||||
full_response = result
|
||||
yield f"data: {json.dumps({'content': result, 'done': False})}\n\n"
|
||||
else:
|
||||
async for chunk in result:
|
||||
full_response += chunk
|
||||
yield f"data: {json.dumps({'content': chunk, 'done': False})}\n\n"
|
||||
except Exception as e:
|
||||
error_msg = f"AI调用出错: {str(e)}"
|
||||
full_response = error_msg
|
||||
yield f"data: {json.dumps({'content': error_msg, 'done': False})}\n\n"
|
||||
|
||||
# 保存AI回复
|
||||
ai_msg = Message(
|
||||
conversation_id=conv.id,
|
||||
role="assistant",
|
||||
content=full_response,
|
||||
)
|
||||
db.add(ai_msg)
|
||||
db.commit()
|
||||
|
||||
yield f"data: {json.dumps({'content': '', 'done': True, 'conversation_id': conv.id})}\n\n"
|
||||
|
||||
return StreamingResponse(generate(), media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.delete("/conversations/{conversation_id}")
|
||||
def delete_conversation(
|
||||
conversation_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""删除对话"""
|
||||
conv = db.query(Conversation).filter(
|
||||
Conversation.id == conversation_id,
|
||||
Conversation.user_id == current_user.id,
|
||||
).first()
|
||||
if not conv:
|
||||
raise HTTPException(status_code=404, detail="对话不存在")
|
||||
|
||||
db.query(Message).filter(Message.conversation_id == conversation_id).delete()
|
||||
db.delete(conv)
|
||||
db.commit()
|
||||
return {"message": "删除成功"}
|
||||
47
backend/routers/search.py
Normal file
47
backend/routers/search.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""搜索路由"""
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import or_
|
||||
|
||||
from database import get_db
|
||||
from models.user import User
|
||||
from models.post import Post
|
||||
from schemas.post import PostResponse, PostListResponse
|
||||
from routers.auth import get_current_user
|
||||
from routers.posts import _enrich_post
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("", response_model=PostListResponse)
|
||||
def search_posts(
|
||||
q: str = Query(..., min_length=1, description="搜索关键词"),
|
||||
category: Optional[str] = None,
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=100),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""全文搜索帖子"""
|
||||
query = db.query(Post).filter(
|
||||
or_(Post.is_public == True, Post.user_id == current_user.id),
|
||||
or_(
|
||||
Post.title.contains(q),
|
||||
Post.content.contains(q),
|
||||
Post.tags.contains(q),
|
||||
),
|
||||
)
|
||||
|
||||
if category:
|
||||
query = query.filter(Post.category == category)
|
||||
|
||||
total = query.count()
|
||||
posts = query.order_by(Post.created_at.desc()).offset((page - 1) * page_size).limit(page_size).all()
|
||||
|
||||
return PostListResponse(
|
||||
items=[_enrich_post(p, db, current_user.id) for p in posts],
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
526
backend/routers/shared_api.py
Normal file
526
backend/routers/shared_api.py
Normal file
@@ -0,0 +1,526 @@
|
||||
"""共享API Hub路由"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, Header, Request
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func as sa_func
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
from datetime import datetime, timedelta
|
||||
import time
|
||||
import hashlib
|
||||
|
||||
from database import get_db
|
||||
from config import SECRET_KEY
|
||||
from models.user import User
|
||||
from models.system_config import SystemConfig
|
||||
from models.shared_api import SharedApiCategory, SharedApi, SharedApiLog
|
||||
from routers.auth import get_current_user, get_admin_user
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# ========== 加密工具 ==========
|
||||
|
||||
_fernet = None
|
||||
|
||||
def _get_fernet():
|
||||
global _fernet
|
||||
if _fernet is None:
|
||||
from cryptography.fernet import Fernet
|
||||
import base64
|
||||
# 从SECRET_KEY派生一个Fernet兼容的key
|
||||
key = hashlib.sha256(SECRET_KEY.encode()).digest()
|
||||
_fernet = Fernet(base64.urlsafe_b64encode(key))
|
||||
return _fernet
|
||||
|
||||
def encrypt_key(plain: str) -> str:
|
||||
if not plain:
|
||||
return ""
|
||||
return _get_fernet().encrypt(plain.encode()).decode()
|
||||
|
||||
def decrypt_key(encrypted: str) -> str:
|
||||
if not encrypted:
|
||||
return ""
|
||||
try:
|
||||
return _get_fernet().decrypt(encrypted.encode()).decode()
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
def mask_key(encrypted: str) -> str:
|
||||
"""脱敏显示"""
|
||||
plain = decrypt_key(encrypted)
|
||||
if not plain:
|
||||
return ""
|
||||
if len(plain) <= 8:
|
||||
return plain[:2] + "***"
|
||||
return plain[:4] + "****" + plain[-4:]
|
||||
|
||||
|
||||
# ========== Hub访问密码机制 ==========
|
||||
|
||||
def _get_hub_password(db: Session) -> str:
|
||||
cfg = db.query(SystemConfig).filter(SystemConfig.key == "api_hub_password").first()
|
||||
return cfg.value if cfg else ""
|
||||
|
||||
def _hub_password_version(db: Session) -> str:
|
||||
"""返回密码哈希前8位作为版本标识,密码变更后旧token自动失效"""
|
||||
pwd = _get_hub_password(db)
|
||||
return pwd[:8] if pwd else "none"
|
||||
|
||||
def _hash_password(pwd: str) -> str:
|
||||
return hashlib.sha256(pwd.encode()).hexdigest()
|
||||
|
||||
def _create_hub_token(user_id: int, pwd_ver: str = "none") -> str:
|
||||
"""创建Hub访问令牌(简单签名,2小时有效)"""
|
||||
from jose import jwt
|
||||
exp = datetime.utcnow() + timedelta(hours=2)
|
||||
return jwt.encode({"sub": str(user_id), "hub": True, "pv": pwd_ver, "exp": exp}, SECRET_KEY, algorithm="HS256")
|
||||
|
||||
def verify_hub_access(
|
||||
x_hub_token: Optional[str] = Header(None),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""验证用户登录 + Hub访问令牌"""
|
||||
if not x_hub_token:
|
||||
raise HTTPException(status_code=403, detail="需要API Hub访问权限,请先验证密码")
|
||||
from jose import jwt, JWTError
|
||||
try:
|
||||
payload = jwt.decode(x_hub_token, SECRET_KEY, algorithms=["HS256"])
|
||||
if not payload.get("hub"):
|
||||
raise HTTPException(status_code=403, detail="无效的Hub令牌")
|
||||
# 检查密码版本是否匹配
|
||||
token_pv = payload.get("pv", "")
|
||||
current_pv = _hub_password_version(db)
|
||||
if token_pv != current_pv:
|
||||
raise HTTPException(status_code=403, detail="密码已变更,请重新验证")
|
||||
except JWTError:
|
||||
raise HTTPException(status_code=403, detail="Hub令牌已过期,请重新验证密码")
|
||||
return current_user
|
||||
|
||||
|
||||
# ========== Schemas ==========
|
||||
|
||||
class HubAuthRequest(BaseModel):
|
||||
password: str
|
||||
|
||||
class CategoryCreate(BaseModel):
|
||||
name: str
|
||||
icon: str = ""
|
||||
|
||||
class CategoryUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
icon: Optional[str] = None
|
||||
sort_order: Optional[int] = None
|
||||
is_active: Optional[bool] = None
|
||||
|
||||
class ApiCreate(BaseModel):
|
||||
category_id: Optional[int] = None
|
||||
name: str
|
||||
description: str = ""
|
||||
base_url: str = ""
|
||||
doc_url: str = ""
|
||||
auth_type: str = "none"
|
||||
api_key: str = "" # 明文传入,后端加密存储
|
||||
api_key_header: str = "Authorization"
|
||||
health_check_url: str = ""
|
||||
tags: str = ""
|
||||
|
||||
class ApiUpdate(BaseModel):
|
||||
category_id: Optional[int] = None
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
base_url: Optional[str] = None
|
||||
doc_url: Optional[str] = None
|
||||
auth_type: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
api_key_header: Optional[str] = None
|
||||
health_check_url: Optional[str] = None
|
||||
tags: Optional[str] = None
|
||||
is_active: Optional[bool] = None
|
||||
|
||||
class ApiTestRequest(BaseModel):
|
||||
method: str = "GET"
|
||||
path: str = ""
|
||||
body: str = ""
|
||||
headers: dict = {}
|
||||
|
||||
|
||||
# ========== 密码认证接口 ==========
|
||||
|
||||
@router.post("/auth")
|
||||
def hub_auth(
|
||||
data: HubAuthRequest,
|
||||
db: Session = Depends(get_db),
|
||||
user: User = Depends(get_current_user),
|
||||
):
|
||||
"""验证Hub访问密码"""
|
||||
stored = _get_hub_password(db)
|
||||
if not stored:
|
||||
raise HTTPException(status_code=400, detail="管理员尚未设置访问密码")
|
||||
if _hash_password(data.password) != stored:
|
||||
raise HTTPException(status_code=403, detail="密码错误")
|
||||
token = _create_hub_token(user.id, _hub_password_version(db))
|
||||
return {"hub_token": token, "expires_in": 7200}
|
||||
|
||||
@router.get("/check-password")
|
||||
def check_password_set(
|
||||
db: Session = Depends(get_db),
|
||||
user: User = Depends(get_current_user),
|
||||
):
|
||||
"""检查是否已设置访问密码"""
|
||||
stored = _get_hub_password(db)
|
||||
return {"has_password": bool(stored)}
|
||||
|
||||
@router.put("/admin/password")
|
||||
def set_hub_password(
|
||||
data: HubAuthRequest,
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user),
|
||||
):
|
||||
"""管理员设置Hub访问密码"""
|
||||
if len(data.password) < 4:
|
||||
raise HTTPException(status_code=400, detail="密码至少4位")
|
||||
hashed = _hash_password(data.password)
|
||||
cfg = db.query(SystemConfig).filter(SystemConfig.key == "api_hub_password").first()
|
||||
if cfg:
|
||||
cfg.value = hashed
|
||||
else:
|
||||
cfg = SystemConfig(key="api_hub_password", value=hashed, description="API Hub访问密码")
|
||||
db.add(cfg)
|
||||
db.commit()
|
||||
return {"message": "密码设置成功"}
|
||||
|
||||
@router.get("/admin/password-status")
|
||||
def get_password_status(
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user),
|
||||
):
|
||||
"""管理员查看密码是否已设置"""
|
||||
stored = _get_hub_password(db)
|
||||
return {"has_password": bool(stored)}
|
||||
|
||||
|
||||
# ========== 分类管理 ==========
|
||||
|
||||
@router.get("/categories")
|
||||
def list_categories(
|
||||
db: Session = Depends(get_db),
|
||||
user: User = Depends(verify_hub_access),
|
||||
):
|
||||
cats = db.query(SharedApiCategory).order_by(SharedApiCategory.sort_order, SharedApiCategory.id).all()
|
||||
return [
|
||||
{"id": c.id, "name": c.name, "icon": c.icon, "sort_order": c.sort_order, "is_active": c.is_active,
|
||||
"api_count": db.query(sa_func.count(SharedApi.id)).filter(SharedApi.category_id == c.id).scalar() or 0}
|
||||
for c in cats
|
||||
]
|
||||
|
||||
@router.post("/categories")
|
||||
def create_category(
|
||||
data: CategoryCreate,
|
||||
db: Session = Depends(get_db),
|
||||
user: User = Depends(verify_hub_access),
|
||||
):
|
||||
existing = db.query(SharedApiCategory).filter(SharedApiCategory.name == data.name).first()
|
||||
if existing:
|
||||
raise HTTPException(status_code=400, detail="分类名称已存在")
|
||||
max_order = db.query(sa_func.max(SharedApiCategory.sort_order)).scalar() or 0
|
||||
cat = SharedApiCategory(name=data.name, icon=data.icon, sort_order=max_order + 1)
|
||||
db.add(cat)
|
||||
db.commit()
|
||||
db.refresh(cat)
|
||||
return {"id": cat.id, "name": cat.name, "icon": cat.icon}
|
||||
|
||||
@router.put("/categories/{cat_id}")
|
||||
def update_category(
|
||||
cat_id: int,
|
||||
data: CategoryUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
user: User = Depends(verify_hub_access),
|
||||
):
|
||||
cat = db.query(SharedApiCategory).filter(SharedApiCategory.id == cat_id).first()
|
||||
if not cat:
|
||||
raise HTTPException(status_code=404, detail="分类不存在")
|
||||
if data.name is not None:
|
||||
cat.name = data.name
|
||||
if data.icon is not None:
|
||||
cat.icon = data.icon
|
||||
if data.sort_order is not None:
|
||||
cat.sort_order = data.sort_order
|
||||
if data.is_active is not None:
|
||||
cat.is_active = data.is_active
|
||||
db.commit()
|
||||
return {"message": "更新成功"}
|
||||
|
||||
@router.delete("/categories/{cat_id}")
|
||||
def delete_category(
|
||||
cat_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
user: User = Depends(verify_hub_access),
|
||||
):
|
||||
cat = db.query(SharedApiCategory).filter(SharedApiCategory.id == cat_id).first()
|
||||
if not cat:
|
||||
raise HTTPException(status_code=404, detail="分类不存在")
|
||||
# 将该分类下的API设为未分类
|
||||
db.query(SharedApi).filter(SharedApi.category_id == cat_id).update({SharedApi.category_id: None})
|
||||
db.delete(cat)
|
||||
db.commit()
|
||||
return {"message": "删除成功"}
|
||||
|
||||
|
||||
# ========== API CRUD ==========
|
||||
|
||||
def _api_to_dict(api, db=None):
|
||||
d = {
|
||||
"id": api.id, "category_id": api.category_id,
|
||||
"name": api.name, "description": api.description,
|
||||
"base_url": api.base_url, "doc_url": api.doc_url,
|
||||
"auth_type": api.auth_type,
|
||||
"api_key_masked": mask_key(api.api_key_encrypted),
|
||||
"api_key_plain": decrypt_key(api.api_key_encrypted) if api.api_key_encrypted else "",
|
||||
"has_api_key": bool(api.api_key_encrypted),
|
||||
"api_key_header": api.api_key_header,
|
||||
"health_check_url": api.health_check_url,
|
||||
"last_check_time": api.last_check_time.isoformat() if api.last_check_time else None,
|
||||
"last_check_status": api.last_check_status,
|
||||
"added_by": api.added_by, "tags": api.tags,
|
||||
"call_count": api.call_count, "is_active": api.is_active,
|
||||
"created_at": api.created_at.isoformat() if api.created_at else None,
|
||||
"updated_at": api.updated_at.isoformat() if api.updated_at else None,
|
||||
}
|
||||
return d
|
||||
|
||||
@router.get("/list")
|
||||
def list_apis(
|
||||
keyword: Optional[str] = None,
|
||||
category_id: Optional[int] = None,
|
||||
db: Session = Depends(get_db),
|
||||
user: User = Depends(verify_hub_access),
|
||||
):
|
||||
query = db.query(SharedApi).filter(SharedApi.is_active == True)
|
||||
if keyword:
|
||||
kw = f"%{keyword}%"
|
||||
query = query.filter(
|
||||
(SharedApi.name.like(kw)) | (SharedApi.description.like(kw)) | (SharedApi.tags.like(kw))
|
||||
)
|
||||
if category_id is not None:
|
||||
query = query.filter(SharedApi.category_id == category_id)
|
||||
apis = query.order_by(SharedApi.call_count.desc(), SharedApi.id.desc()).all()
|
||||
return {"items": [_api_to_dict(a) for a in apis], "total": len(apis)}
|
||||
|
||||
@router.post("/")
|
||||
def create_api(
|
||||
data: ApiCreate,
|
||||
db: Session = Depends(get_db),
|
||||
user: User = Depends(verify_hub_access),
|
||||
):
|
||||
api = SharedApi(
|
||||
category_id=data.category_id, name=data.name, description=data.description,
|
||||
base_url=data.base_url, doc_url=data.doc_url,
|
||||
auth_type=data.auth_type,
|
||||
api_key_encrypted=encrypt_key(data.api_key) if data.api_key else "",
|
||||
api_key_header=data.api_key_header,
|
||||
health_check_url=data.health_check_url,
|
||||
tags=data.tags, added_by=user.id,
|
||||
)
|
||||
db.add(api)
|
||||
db.commit()
|
||||
db.refresh(api)
|
||||
return _api_to_dict(api)
|
||||
|
||||
@router.put("/{api_id}")
|
||||
def update_api(
|
||||
api_id: int,
|
||||
data: ApiUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
user: User = Depends(verify_hub_access),
|
||||
):
|
||||
api = db.query(SharedApi).filter(SharedApi.id == api_id).first()
|
||||
if not api:
|
||||
raise HTTPException(status_code=404, detail="API不存在")
|
||||
if data.category_id is not None:
|
||||
api.category_id = data.category_id
|
||||
if data.name is not None:
|
||||
api.name = data.name
|
||||
if data.description is not None:
|
||||
api.description = data.description
|
||||
if data.base_url is not None:
|
||||
api.base_url = data.base_url
|
||||
if data.doc_url is not None:
|
||||
api.doc_url = data.doc_url
|
||||
if data.auth_type is not None:
|
||||
api.auth_type = data.auth_type
|
||||
if data.api_key is not None and data.api_key != "":
|
||||
api.api_key_encrypted = encrypt_key(data.api_key)
|
||||
if data.api_key_header is not None:
|
||||
api.api_key_header = data.api_key_header
|
||||
if data.health_check_url is not None:
|
||||
api.health_check_url = data.health_check_url
|
||||
if data.tags is not None:
|
||||
api.tags = data.tags
|
||||
if data.is_active is not None:
|
||||
api.is_active = data.is_active
|
||||
db.commit()
|
||||
db.refresh(api)
|
||||
return _api_to_dict(api)
|
||||
|
||||
@router.delete("/{api_id}")
|
||||
def delete_api(
|
||||
api_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
user: User = Depends(verify_hub_access),
|
||||
):
|
||||
api = db.query(SharedApi).filter(SharedApi.id == api_id).first()
|
||||
if not api:
|
||||
raise HTTPException(status_code=404, detail="API不存在")
|
||||
db.query(SharedApiLog).filter(SharedApiLog.api_id == api_id).delete()
|
||||
db.delete(api)
|
||||
db.commit()
|
||||
return {"message": "删除成功"}
|
||||
|
||||
|
||||
# ========== API测试 ==========
|
||||
|
||||
@router.post("/{api_id}/test")
|
||||
async def test_api(
|
||||
api_id: int,
|
||||
data: ApiTestRequest,
|
||||
db: Session = Depends(get_db),
|
||||
user: User = Depends(verify_hub_access),
|
||||
):
|
||||
"""在线测试API(后端代理请求)"""
|
||||
api = db.query(SharedApi).filter(SharedApi.id == api_id).first()
|
||||
if not api:
|
||||
raise HTTPException(status_code=404, detail="API不存在")
|
||||
|
||||
url = api.base_url.rstrip("/")
|
||||
if data.path:
|
||||
url = url + "/" + data.path.lstrip("/")
|
||||
|
||||
headers = dict(data.headers) if data.headers else {}
|
||||
# 注入认证信息
|
||||
if api.auth_type != "none" and api.api_key_encrypted:
|
||||
key = decrypt_key(api.api_key_encrypted)
|
||||
if key:
|
||||
if api.auth_type == "bearer":
|
||||
headers[api.api_key_header] = f"Bearer {key}"
|
||||
elif api.auth_type == "api_key":
|
||||
headers[api.api_key_header] = key
|
||||
elif api.auth_type == "basic":
|
||||
import base64
|
||||
headers["Authorization"] = f"Basic {base64.b64encode(key.encode()).decode()}"
|
||||
|
||||
import httpx
|
||||
start = time.time()
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=15) as client:
|
||||
if data.method.upper() == "POST":
|
||||
resp = await client.post(url, headers=headers, content=data.body or None)
|
||||
elif data.method.upper() == "PUT":
|
||||
resp = await client.put(url, headers=headers, content=data.body or None)
|
||||
elif data.method.upper() == "DELETE":
|
||||
resp = await client.delete(url, headers=headers)
|
||||
else:
|
||||
resp = await client.get(url, headers=headers)
|
||||
elapsed = int((time.time() - start) * 1000)
|
||||
# 记录日志
|
||||
log = SharedApiLog(
|
||||
api_id=api_id, user_id=user.id, action="test",
|
||||
request_url=url, response_status=resp.status_code, response_time_ms=elapsed,
|
||||
)
|
||||
db.add(log)
|
||||
api.call_count = (api.call_count or 0) + 1
|
||||
db.commit()
|
||||
# 限制返回体大小
|
||||
body = resp.text[:5000] if len(resp.text) > 5000 else resp.text
|
||||
return {
|
||||
"status_code": resp.status_code,
|
||||
"response_time_ms": elapsed,
|
||||
"headers": dict(resp.headers),
|
||||
"body": body,
|
||||
}
|
||||
except Exception as e:
|
||||
elapsed = int((time.time() - start) * 1000)
|
||||
log = SharedApiLog(
|
||||
api_id=api_id, user_id=user.id, action="test",
|
||||
request_url=url, response_status=0, response_time_ms=elapsed,
|
||||
)
|
||||
db.add(log)
|
||||
db.commit()
|
||||
return {"status_code": 0, "response_time_ms": elapsed, "headers": {}, "body": str(e)}
|
||||
|
||||
|
||||
@router.post("/{api_id}/health-check")
|
||||
async def health_check(
|
||||
api_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
user: User = Depends(verify_hub_access),
|
||||
):
|
||||
"""健康检查"""
|
||||
api = db.query(SharedApi).filter(SharedApi.id == api_id).first()
|
||||
if not api:
|
||||
raise HTTPException(status_code=404, detail="API不存在")
|
||||
|
||||
check_url = api.health_check_url or api.base_url
|
||||
import httpx
|
||||
start = time.time()
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
resp = await client.get(check_url)
|
||||
elapsed = int((time.time() - start) * 1000)
|
||||
status = "ok" if resp.status_code < 400 else "error"
|
||||
except Exception:
|
||||
elapsed = int((time.time() - start) * 1000)
|
||||
status = "error"
|
||||
|
||||
api.last_check_time = datetime.utcnow()
|
||||
api.last_check_status = status
|
||||
log = SharedApiLog(
|
||||
api_id=api_id, user_id=user.id, action="health_check",
|
||||
request_url=check_url, response_status=resp.status_code if status == "ok" else 0,
|
||||
response_time_ms=elapsed,
|
||||
)
|
||||
db.add(log)
|
||||
db.commit()
|
||||
return {"status": status, "response_time_ms": elapsed}
|
||||
|
||||
|
||||
# ========== 日志与统计 ==========
|
||||
|
||||
@router.get("/{api_id}/logs")
|
||||
def get_api_logs(
|
||||
api_id: int,
|
||||
limit: int = 20,
|
||||
db: Session = Depends(get_db),
|
||||
user: User = Depends(verify_hub_access),
|
||||
):
|
||||
logs = (
|
||||
db.query(SharedApiLog)
|
||||
.filter(SharedApiLog.api_id == api_id)
|
||||
.order_by(SharedApiLog.id.desc())
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
return [
|
||||
{
|
||||
"id": l.id, "action": l.action, "request_url": l.request_url,
|
||||
"response_status": l.response_status, "response_time_ms": l.response_time_ms,
|
||||
"user_id": l.user_id,
|
||||
"created_at": l.created_at.isoformat() if l.created_at else None,
|
||||
}
|
||||
for l in logs
|
||||
]
|
||||
|
||||
@router.get("/stats")
|
||||
def get_stats(
|
||||
db: Session = Depends(get_db),
|
||||
user: User = Depends(verify_hub_access),
|
||||
):
|
||||
total_apis = db.query(sa_func.count(SharedApi.id)).filter(SharedApi.is_active == True).scalar() or 0
|
||||
total_calls = db.query(sa_func.sum(SharedApi.call_count)).scalar() or 0
|
||||
total_categories = db.query(sa_func.count(SharedApiCategory.id)).scalar() or 0
|
||||
healthy = db.query(sa_func.count(SharedApi.id)).filter(SharedApi.last_check_status == "ok").scalar() or 0
|
||||
return {
|
||||
"total_apis": total_apis,
|
||||
"total_calls": total_calls,
|
||||
"total_categories": total_categories,
|
||||
"healthy_count": healthy,
|
||||
}
|
||||
206
backend/routers/upload.py
Normal file
206
backend/routers/upload.py
Normal file
@@ -0,0 +1,206 @@
|
||||
"""通用文件上传路由 - 腾讯云COS"""
|
||||
import uuid
|
||||
import os
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter, UploadFile, File, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from routers.auth import get_current_user
|
||||
from models.user import User
|
||||
from models.system_config import SystemConfig
|
||||
from database import get_db
|
||||
from config import MAX_UPLOAD_SIZE, MAX_ATTACHMENT_SIZE, UPLOAD_DIR
|
||||
from models.attachment import Attachment
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
ALLOWED_IMAGE_TYPES = ["image/jpeg", "image/png", "image/gif", "image/webp"]
|
||||
|
||||
ALLOWED_ATTACHMENT_TYPES = [
|
||||
"application/pdf",
|
||||
"application/msword",
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
"application/vnd.ms-excel",
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
"application/vnd.ms-powerpoint",
|
||||
"application/vnd.openxmlformats-officedocument.presentationml.presentation",
|
||||
"application/zip",
|
||||
"application/x-zip-compressed",
|
||||
"application/x-rar-compressed",
|
||||
"application/vnd.rar",
|
||||
]
|
||||
ALLOWED_ATTACHMENT_EXTS = {
|
||||
"pdf", "doc", "docx", "xls", "xlsx", "ppt", "pptx", "zip", "rar",
|
||||
}
|
||||
|
||||
|
||||
def _get_cos_config(db: Session) -> dict:
|
||||
"""从数据库读取COS配置"""
|
||||
keys = ["cos_secret_id", "cos_secret_key", "cos_bucket", "cos_region", "cos_custom_domain"]
|
||||
config = {}
|
||||
for k in keys:
|
||||
row = db.query(SystemConfig).filter(SystemConfig.key == k).first()
|
||||
config[k] = row.value if row else ""
|
||||
return config
|
||||
|
||||
|
||||
def _get_cos_client(db: Session):
|
||||
"""获取COS客户端实例,未配置则返回None"""
|
||||
config = _get_cos_config(db)
|
||||
secret_id = config.get("cos_secret_id", "")
|
||||
secret_key = config.get("cos_secret_key", "")
|
||||
bucket = config.get("cos_bucket", "")
|
||||
region = config.get("cos_region", "")
|
||||
if not all([secret_id, secret_key, bucket, region]):
|
||||
return None, config
|
||||
try:
|
||||
from qcloud_cos import CosConfig, CosS3Client
|
||||
cos_config = CosConfig(Region=region, SecretId=secret_id, SecretKey=secret_key)
|
||||
client = CosS3Client(cos_config)
|
||||
return client, config
|
||||
except ImportError:
|
||||
return None, config
|
||||
|
||||
|
||||
def _build_cos_url(config: dict, object_key: str) -> str:
|
||||
"""构建COS访问URL"""
|
||||
custom_domain = config.get("cos_custom_domain", "")
|
||||
if custom_domain:
|
||||
return f"https://{custom_domain}/{object_key}"
|
||||
bucket = config.get("cos_bucket", "")
|
||||
region = config.get("cos_region", "")
|
||||
return f"https://{bucket}.cos.{region}.myqcloud.com/{object_key}"
|
||||
|
||||
|
||||
@router.post("/image")
|
||||
async def upload_image(
|
||||
file: UploadFile = File(...),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""上传图片,优先OSS,未配置则本地存储"""
|
||||
# 验证文件类型
|
||||
if file.content_type not in ALLOWED_IMAGE_TYPES:
|
||||
raise HTTPException(status_code=400, detail="不支持的文件类型,仅支持 JPG/PNG/GIF/WEBP")
|
||||
|
||||
# 读取文件内容
|
||||
content = await file.read()
|
||||
|
||||
# 验证文件大小
|
||||
if len(content) > MAX_UPLOAD_SIZE:
|
||||
raise HTTPException(status_code=400, detail=f"文件过大,最大支持 {MAX_UPLOAD_SIZE // (1024*1024)}MB")
|
||||
|
||||
# 生成唯一文件名
|
||||
ext = file.filename.split(".")[-1].lower() if "." in file.filename else "png"
|
||||
date_prefix = datetime.now().strftime("%Y/%m")
|
||||
filename = f"{uuid.uuid4().hex}.{ext}"
|
||||
object_key = f"images/{date_prefix}/{filename}"
|
||||
|
||||
# 尝试COS上传(从数据库读取配置)
|
||||
client, config = _get_cos_client(db)
|
||||
if client:
|
||||
try:
|
||||
client.put_object(
|
||||
Bucket=config.get("cos_bucket", ""),
|
||||
Body=content,
|
||||
Key=object_key,
|
||||
ContentType=file.content_type,
|
||||
)
|
||||
url = _build_cos_url(config, object_key)
|
||||
return {"url": url, "storage": "cos"}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"COS上传失败: {str(e)}")
|
||||
else:
|
||||
# 降级到本地存储
|
||||
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||
filepath = os.path.join(UPLOAD_DIR, filename)
|
||||
with open(filepath, "wb") as f:
|
||||
f.write(content)
|
||||
return {"url": f"/uploads/{filename}", "storage": "local"}
|
||||
|
||||
|
||||
@router.post("/attachment")
|
||||
async def upload_attachment(
|
||||
file: UploadFile = File(...),
|
||||
post_id: int = 0,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""上传附件到COS,支持 PDF/Word/Excel/PPT/ZIP/RAR"""
|
||||
# 验证文件扩展名
|
||||
ext = file.filename.split(".")[-1].lower() if "." in file.filename else ""
|
||||
if ext not in ALLOWED_ATTACHMENT_EXTS:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"不支持的文件类型,仅支持: {', '.join(sorted(ALLOWED_ATTACHMENT_EXTS))}",
|
||||
)
|
||||
|
||||
# 读取文件内容
|
||||
content = await file.read()
|
||||
if len(content) > MAX_ATTACHMENT_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"文件过大,最大支持 {MAX_ATTACHMENT_SIZE // (1024*1024)}MB",
|
||||
)
|
||||
|
||||
# 生成唯一文件名
|
||||
date_prefix = datetime.now().strftime("%Y/%m")
|
||||
unique_name = f"{uuid.uuid4().hex}.{ext}"
|
||||
object_key = f"attachments/{date_prefix}/{unique_name}"
|
||||
|
||||
# 上传到COS
|
||||
client, config = _get_cos_client(db)
|
||||
if not client:
|
||||
raise HTTPException(status_code=500, detail="对象存储未配置,无法上传附件")
|
||||
|
||||
try:
|
||||
client.put_object(
|
||||
Bucket=config.get("cos_bucket", ""),
|
||||
Body=content,
|
||||
Key=object_key,
|
||||
ContentType=file.content_type or "application/octet-stream",
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"COS上传失败: {str(e)}")
|
||||
|
||||
url = _build_cos_url(config, object_key)
|
||||
|
||||
# 写入数据库
|
||||
attachment = Attachment(
|
||||
post_id=post_id if post_id else None,
|
||||
user_id=current_user.id,
|
||||
filename=file.filename,
|
||||
storage_key=object_key,
|
||||
url=url,
|
||||
file_size=len(content),
|
||||
file_type=file.content_type or "application/octet-stream",
|
||||
)
|
||||
db.add(attachment)
|
||||
db.commit()
|
||||
db.refresh(attachment)
|
||||
|
||||
return {
|
||||
"id": attachment.id,
|
||||
"filename": attachment.filename,
|
||||
"url": attachment.url,
|
||||
"file_size": attachment.file_size,
|
||||
"file_type": attachment.file_type,
|
||||
}
|
||||
|
||||
|
||||
@router.put("/attachment/{attachment_id}/post")
|
||||
async def update_attachment_post(
|
||||
attachment_id: int,
|
||||
post_id: int = 0,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""将附件关联到帖子(新建帖子后回填 post_id)"""
|
||||
attachment = db.query(Attachment).filter(
|
||||
Attachment.id == attachment_id,
|
||||
Attachment.user_id == current_user.id,
|
||||
).first()
|
||||
if not attachment:
|
||||
raise HTTPException(status_code=404, detail="附件不存在")
|
||||
attachment.post_id = post_id
|
||||
db.commit()
|
||||
return {"message": "ok"}
|
||||
214
backend/routers/users.py
Normal file
214
backend/routers/users.py
Normal file
@@ -0,0 +1,214 @@
|
||||
"""用户主页和关注系统路由"""
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func
|
||||
from database import get_db
|
||||
from models.user import User
|
||||
from models.post import Post
|
||||
from models.follow import Follow
|
||||
from models.like import Collect
|
||||
from models.notification import Notification
|
||||
from routers.auth import get_current_user
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/{user_id}")
|
||||
def get_user_profile(
|
||||
user_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取用户主页信息"""
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
|
||||
post_count = db.query(func.count(Post.id)).filter(Post.user_id == user_id, Post.is_public == True).scalar()
|
||||
follower_count = db.query(func.count(Follow.id)).filter(Follow.following_id == user_id).scalar()
|
||||
following_count = db.query(func.count(Follow.id)).filter(Follow.follower_id == user_id).scalar()
|
||||
is_following = db.query(Follow).filter(
|
||||
Follow.follower_id == current_user.id, Follow.following_id == user_id
|
||||
).first() is not None
|
||||
|
||||
return {
|
||||
"id": user.id,
|
||||
"username": user.username,
|
||||
"email": user.email,
|
||||
"avatar": user.avatar,
|
||||
"created_at": user.created_at,
|
||||
"post_count": post_count,
|
||||
"follower_count": follower_count,
|
||||
"following_count": following_count,
|
||||
"is_following": is_following,
|
||||
"is_self": current_user.id == user_id,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/{user_id}/follow")
|
||||
def toggle_follow(
|
||||
user_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""关注/取消关注"""
|
||||
if current_user.id == user_id:
|
||||
raise HTTPException(status_code=400, detail="不能关注自己")
|
||||
|
||||
target = db.query(User).filter(User.id == user_id).first()
|
||||
if not target:
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
|
||||
existing = db.query(Follow).filter(
|
||||
Follow.follower_id == current_user.id, Follow.following_id == user_id
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
db.delete(existing)
|
||||
db.commit()
|
||||
return {"followed": False}
|
||||
else:
|
||||
follow = Follow(follower_id=current_user.id, following_id=user_id)
|
||||
db.add(follow)
|
||||
# 创建通知
|
||||
notif = Notification(
|
||||
user_id=user_id,
|
||||
type="follow",
|
||||
content=f"{current_user.username} 关注了你",
|
||||
from_user_id=current_user.id,
|
||||
related_id=current_user.id,
|
||||
)
|
||||
db.add(notif)
|
||||
db.commit()
|
||||
return {"followed": True}
|
||||
|
||||
|
||||
@router.get("/{user_id}/posts")
|
||||
def get_user_posts(
|
||||
user_id: int,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取用户的帖子列表"""
|
||||
query = db.query(Post).filter(Post.user_id == user_id)
|
||||
if user_id != current_user.id:
|
||||
query = query.filter(Post.is_public == True)
|
||||
posts = query.order_by(Post.created_at.desc()).offset((page - 1) * page_size).limit(page_size).all()
|
||||
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
username = user.username if user else "未知"
|
||||
avatar = user.avatar if user else ""
|
||||
|
||||
return [
|
||||
{
|
||||
"id": p.id, "title": p.title, "content": p.content[:200],
|
||||
"category": p.category, "tags": p.tags,
|
||||
"view_count": p.view_count, "like_count": p.like_count,
|
||||
"comment_count": p.comment_count, "collect_count": p.collect_count,
|
||||
"created_at": p.created_at, "updated_at": p.updated_at,
|
||||
"author": {"id": user_id, "username": username, "avatar": avatar},
|
||||
}
|
||||
for p in posts
|
||||
]
|
||||
|
||||
|
||||
@router.get("/{user_id}/followers")
|
||||
def get_followers(
|
||||
user_id: int,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取粉丝列表"""
|
||||
follows = (
|
||||
db.query(Follow)
|
||||
.filter(Follow.following_id == user_id)
|
||||
.order_by(Follow.created_at.desc())
|
||||
.offset((page - 1) * page_size).limit(page_size).all()
|
||||
)
|
||||
user_ids = [f.follower_id for f in follows]
|
||||
users = db.query(User).filter(User.id.in_(user_ids)).all() if user_ids else []
|
||||
user_map = {u.id: u for u in users}
|
||||
|
||||
return [
|
||||
{
|
||||
"id": uid,
|
||||
"username": user_map[uid].username if uid in user_map else "",
|
||||
"avatar": user_map[uid].avatar if uid in user_map else "",
|
||||
}
|
||||
for uid in user_ids
|
||||
]
|
||||
|
||||
|
||||
@router.get("/{user_id}/following")
|
||||
def get_following(
|
||||
user_id: int,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取关注列表"""
|
||||
follows = (
|
||||
db.query(Follow)
|
||||
.filter(Follow.follower_id == user_id)
|
||||
.order_by(Follow.created_at.desc())
|
||||
.offset((page - 1) * page_size).limit(page_size).all()
|
||||
)
|
||||
user_ids = [f.following_id for f in follows]
|
||||
users = db.query(User).filter(User.id.in_(user_ids)).all() if user_ids else []
|
||||
user_map = {u.id: u for u in users}
|
||||
|
||||
return [
|
||||
{
|
||||
"id": uid,
|
||||
"username": user_map[uid].username if uid in user_map else "",
|
||||
"avatar": user_map[uid].avatar if uid in user_map else "",
|
||||
}
|
||||
for uid in user_ids
|
||||
]
|
||||
|
||||
|
||||
@router.get("/{user_id}/collects")
|
||||
def get_user_collects(
|
||||
user_id: int,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取用户收藏的帖子"""
|
||||
collects = (
|
||||
db.query(Collect)
|
||||
.filter(Collect.user_id == user_id)
|
||||
.order_by(Collect.created_at.desc())
|
||||
.offset((page - 1) * page_size).limit(page_size).all()
|
||||
)
|
||||
post_ids = [c.post_id for c in collects]
|
||||
if not post_ids:
|
||||
return []
|
||||
|
||||
posts = db.query(Post).filter(Post.id.in_(post_ids)).all()
|
||||
post_map = {p.id: p for p in posts}
|
||||
author_ids = list(set(p.user_id for p in posts))
|
||||
authors = db.query(User).filter(User.id.in_(author_ids)).all()
|
||||
author_map = {a.id: a for a in authors}
|
||||
|
||||
result = []
|
||||
for pid in post_ids:
|
||||
p = post_map.get(pid)
|
||||
if not p:
|
||||
continue
|
||||
a = author_map.get(p.user_id)
|
||||
result.append({
|
||||
"id": p.id, "title": p.title, "content": p.content[:200],
|
||||
"category": p.category, "tags": p.tags,
|
||||
"view_count": p.view_count, "like_count": p.like_count,
|
||||
"comment_count": p.comment_count, "collect_count": p.collect_count,
|
||||
"created_at": p.created_at,
|
||||
"author": {"id": p.user_id, "username": a.username if a else "", "avatar": a.avatar if a else ""},
|
||||
})
|
||||
return result
|
||||
274
backend/routers/web_search.py
Normal file
274
backend/routers/web_search.py
Normal file
@@ -0,0 +1,274 @@
|
||||
"""联网搜索路由 - 使用豆包大模型 + 火山方舟 web_search"""
|
||||
import json
|
||||
import httpx
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
from pydantic import BaseModel
|
||||
|
||||
from database import get_db
|
||||
from models.user import User
|
||||
from models.conversation import Conversation, Message
|
||||
from schemas.conversation import ConversationResponse, ConversationDetail, MessageResponse
|
||||
from routers.auth import get_current_user
|
||||
from config import ARK_API_KEY, ARK_ENDPOINT, ARK_BASE_URL
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
WEB_SEARCH_SYSTEM_PROMPT = """你是一个智能联网搜索助手。你可以通过联网搜索获取最新信息来回答用户的问题。
|
||||
|
||||
回答要求:
|
||||
1. 基于搜索到的最新信息给出准确、详细的回答
|
||||
2. 使用清晰的 Markdown 格式组织内容
|
||||
3. 如果涉及时效性信息,注明信息的时间
|
||||
4. 对搜索结果进行整合和总结,而非简单罗列
|
||||
5. 如果搜索结果不足以回答问题,诚实告知并给出建议"""
|
||||
|
||||
|
||||
class WebSearchRequest(BaseModel):
|
||||
conversation_id: Optional[int] = None
|
||||
content: str
|
||||
model_config_id: Optional[int] = None
|
||||
|
||||
|
||||
@router.post("/search")
|
||||
async def web_search(
|
||||
request: WebSearchRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""联网搜索 - 流式输出"""
|
||||
# 获取模型配置:优先用指定的,否则用 .env 的 ARK 配置
|
||||
from services.ai_service import _get_db_model_config
|
||||
ark_api_key = ARK_API_KEY
|
||||
ark_endpoint = ARK_ENDPOINT
|
||||
ark_base_url = ARK_BASE_URL
|
||||
search_count = 5
|
||||
|
||||
if request.model_config_id:
|
||||
cfg = _get_db_model_config("", request.model_config_id)
|
||||
if cfg:
|
||||
ark_api_key = cfg["api_key"]
|
||||
ark_endpoint = cfg["model"]
|
||||
ark_base_url = cfg["base_url"] or ARK_BASE_URL
|
||||
search_count = cfg.get("web_search_count", 5)
|
||||
|
||||
if not ark_api_key:
|
||||
raise HTTPException(status_code=500, detail="未配置火山方舟 API Key")
|
||||
|
||||
# 创建或获取对话
|
||||
if request.conversation_id:
|
||||
conv = db.query(Conversation).filter(
|
||||
Conversation.id == request.conversation_id,
|
||||
Conversation.user_id == current_user.id,
|
||||
).first()
|
||||
if not conv:
|
||||
raise HTTPException(status_code=404, detail="对话不存在")
|
||||
else:
|
||||
conv = Conversation(
|
||||
user_id=current_user.id,
|
||||
title=request.content[:50] if request.content else "新搜索",
|
||||
type="web_search",
|
||||
)
|
||||
db.add(conv)
|
||||
db.commit()
|
||||
db.refresh(conv)
|
||||
|
||||
# 保存用户消息
|
||||
user_msg = Message(
|
||||
conversation_id=conv.id,
|
||||
role="user",
|
||||
content=request.content,
|
||||
)
|
||||
db.add(user_msg)
|
||||
db.commit()
|
||||
|
||||
# 构建历史消息
|
||||
history_msgs = (
|
||||
db.query(Message)
|
||||
.filter(Message.conversation_id == conv.id)
|
||||
.order_by(Message.created_at.asc())
|
||||
.all()
|
||||
)
|
||||
|
||||
messages = [{"role": "system", "content": WEB_SEARCH_SYSTEM_PROMPT}]
|
||||
for msg in history_msgs:
|
||||
messages.append({"role": msg.role, "content": msg.content})
|
||||
|
||||
# 流式调用火山方舟 API
|
||||
url = f"{ark_base_url}/chat/completions"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {ark_api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
|
||||
async def generate():
|
||||
full_response = ""
|
||||
try:
|
||||
payload_with_search = {
|
||||
"model": ark_endpoint,
|
||||
"messages": messages,
|
||||
"stream": True,
|
||||
"tools": [
|
||||
{
|
||||
"type": "web_search",
|
||||
"web_search": {"enable": True, "search_result_count": max(1, min(50, search_count))},
|
||||
}
|
||||
],
|
||||
}
|
||||
payload_without_search = {
|
||||
"model": ark_endpoint,
|
||||
"messages": messages,
|
||||
"stream": True,
|
||||
}
|
||||
|
||||
# 先尝试带联网搜索调用
|
||||
use_fallback = False
|
||||
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||
async with client.stream("POST", url, json=payload_with_search, headers=headers) as resp:
|
||||
if resp.status_code == 400:
|
||||
# 模型不支持 web_search tools,降级到普通调用
|
||||
await resp.aread()
|
||||
use_fallback = True
|
||||
elif resp.status_code != 200:
|
||||
error_body = await resp.aread()
|
||||
error_msg = f"API 调用失败 ({resp.status_code}): {error_body.decode()}"
|
||||
yield f"data: {json.dumps({'content': error_msg, 'done': False})}\n\n"
|
||||
full_response = error_msg
|
||||
else:
|
||||
buffer = ""
|
||||
async for chunk in resp.aiter_text():
|
||||
buffer += chunk
|
||||
while "\n" in buffer:
|
||||
line, buffer = buffer.split("\n", 1)
|
||||
line = line.strip()
|
||||
if not line or not line.startswith("data: "):
|
||||
continue
|
||||
data_str = line[6:]
|
||||
if data_str == "[DONE]":
|
||||
continue
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
choices = data.get("choices", [])
|
||||
if choices:
|
||||
delta = choices[0].get("delta", {})
|
||||
content = delta.get("content", "")
|
||||
if content:
|
||||
full_response += content
|
||||
yield f"data: {json.dumps({'content': content, 'done': False})}\n\n"
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 降级:不带 web_search tools 重试
|
||||
if use_fallback:
|
||||
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||
async with client.stream("POST", url, json=payload_without_search, headers=headers) as resp:
|
||||
if resp.status_code != 200:
|
||||
error_body = await resp.aread()
|
||||
error_msg = f"API 调用失败 ({resp.status_code}): {error_body.decode()}"
|
||||
yield f"data: {json.dumps({'content': error_msg, 'done': False})}\n\n"
|
||||
full_response = error_msg
|
||||
else:
|
||||
buffer = ""
|
||||
async for chunk in resp.aiter_text():
|
||||
buffer += chunk
|
||||
while "\n" in buffer:
|
||||
line, buffer = buffer.split("\n", 1)
|
||||
line = line.strip()
|
||||
if not line or not line.startswith("data: "):
|
||||
continue
|
||||
data_str = line[6:]
|
||||
if data_str == "[DONE]":
|
||||
continue
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
choices = data.get("choices", [])
|
||||
if choices:
|
||||
delta = choices[0].get("delta", {})
|
||||
content = delta.get("content", "")
|
||||
if content:
|
||||
full_response += content
|
||||
yield f"data: {json.dumps({'content': content, 'done': False})}\n\n"
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
except Exception as e:
|
||||
error_msg = f"联网搜索出错: {str(e)}"
|
||||
if not full_response:
|
||||
full_response = error_msg
|
||||
yield f"data: {json.dumps({'content': error_msg, 'done': False})}\n\n"
|
||||
|
||||
# 保存AI回复
|
||||
if full_response:
|
||||
ai_msg = Message(
|
||||
conversation_id=conv.id,
|
||||
role="assistant",
|
||||
content=full_response,
|
||||
)
|
||||
db.add(ai_msg)
|
||||
db.commit()
|
||||
|
||||
yield f"data: {json.dumps({'content': '', 'done': True, 'conversation_id': conv.id})}\n\n"
|
||||
|
||||
return StreamingResponse(generate(), media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.get("/conversations", response_model=List[ConversationResponse])
|
||||
def get_conversations(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取联网搜索对话列表"""
|
||||
conversations = (
|
||||
db.query(Conversation)
|
||||
.filter(Conversation.user_id == current_user.id, Conversation.type == "web_search")
|
||||
.order_by(Conversation.updated_at.desc())
|
||||
.all()
|
||||
)
|
||||
return [ConversationResponse.model_validate(c) for c in conversations]
|
||||
|
||||
|
||||
@router.get("/conversations/{conversation_id}", response_model=ConversationDetail)
|
||||
def get_conversation_detail(
|
||||
conversation_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取对话详情"""
|
||||
conv = db.query(Conversation).filter(
|
||||
Conversation.id == conversation_id,
|
||||
Conversation.user_id == current_user.id,
|
||||
).first()
|
||||
if not conv:
|
||||
raise HTTPException(status_code=404, detail="对话不存在")
|
||||
|
||||
msgs = (
|
||||
db.query(Message)
|
||||
.filter(Message.conversation_id == conversation_id)
|
||||
.order_by(Message.created_at.asc())
|
||||
.all()
|
||||
)
|
||||
result = ConversationDetail.model_validate(conv)
|
||||
result.messages = [MessageResponse.model_validate(m) for m in msgs]
|
||||
return result
|
||||
|
||||
|
||||
@router.delete("/conversations/{conversation_id}")
|
||||
def delete_conversation(
|
||||
conversation_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""删除对话"""
|
||||
conv = db.query(Conversation).filter(
|
||||
Conversation.id == conversation_id,
|
||||
Conversation.user_id == current_user.id,
|
||||
).first()
|
||||
if not conv:
|
||||
raise HTTPException(status_code=404, detail="对话不存在")
|
||||
|
||||
db.query(Message).filter(Message.conversation_id == conversation_id).delete()
|
||||
db.delete(conv)
|
||||
db.commit()
|
||||
return {"message": "删除成功"}
|
||||
Reference in New Issue
Block a user