Files
bianchengshequ/backend/routers/posts.py

441 lines
14 KiB
Python

"""经验知识库路由"""
import json
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session
from sqlalchemy import or_
from database import get_db
from models.user import User
from models.post import Post
from models.comment import Comment
from models.like import Like, Collect
from models.follow import Follow
from models.notification import Notification
from models.attachment import Attachment
from schemas.post import (
PostCreate, PostUpdate, PostResponse, PostListResponse,
CommentCreate, CommentResponse,
)
from routers.auth import get_current_user
# Router holding every post-related endpoint of this module (feeds, post
# CRUD, likes/collects, comments, attachments). Presumably included by the
# application with a URL prefix — confirm in the app setup.
router = APIRouter()
import re
def _extract_cover_image(content: str) -> str:
"""从Markdown内容中提取第一张图片作为封面"""
if not content:
return ""
match = re.search(r'!\[.*?\]\((.*?)\)', content)
if match:
return match.group(1)
img_match = re.search(r'<img[^>]+src=["\']([^"\']+)["\']', content)
if img_match:
return img_match.group(1)
return ""
def _enrich_post_with_author(post: Post, db: Session) -> dict:
    """Serialize ``post`` as a feed item with an embedded author summary.

    Used by the feed / hot / latest list endpoints. The body is cut to a
    200-character preview, and the cover image is extracted from the full body.

    Args:
        post: The post row to serialize.
        db: Session used to look up the author.

    Returns:
        A plain dict. If the author row is missing, ``author.username`` falls
        back to "未知" and ``author.avatar`` to ``""``.
    """
    author = db.query(User).filter(User.id == post.user_id).first()
    # Normalize a NULL body once so the preview slice cannot raise TypeError.
    # _extract_cover_image treats "" and None identically (no image).
    content = post.content or ""
    return {
        "id": post.id,
        "title": post.title,
        "content": content[:200],
        "category": post.category,
        "tags": post.tags,
        "cover_image": _extract_cover_image(content),
        "view_count": post.view_count,
        "like_count": post.like_count,
        "comment_count": post.comment_count,
        "collect_count": post.collect_count,
        "created_at": post.created_at,
        "updated_at": post.updated_at,
        "author": {
            "id": post.user_id,
            "username": author.username if author else "未知",
            "avatar": author.avatar if author else "",
        },
    }
@router.get("/feed")
def get_feed(
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    category: Optional[str] = None,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Return a paginated feed of posts by users the current user follows.

    Only published (non-draft) public posts are included, newest first.
    ``page`` and ``page_size`` use the same bounds as the main post listing,
    so a bad value gets a 422 instead of a negative offset or an unbounded page.

    Args:
        page: 1-based page number.
        page_size: Items per page (1-100).
        category: Optional category filter.

    Returns:
        ``{"items", "total", "page", "page_size"}``. Items come from
        ``_enrich_post_with_author``.
    """
    following_ids = [
        row.following_id
        for row in db.query(Follow.following_id)
        .filter(Follow.follower_id == current_user.id)
        .all()
    ]
    if not following_ids:
        return {"items": [], "total": 0, "page": page, "page_size": page_size}
    query = db.query(Post).filter(
        Post.user_id.in_(following_ids),
        Post.is_public == True,  # noqa: E712 - SQL expression, not a bool test
        Post.is_draft == False,  # noqa: E712
    )
    if category:
        query = query.filter(Post.category == category)
    total = query.count()
    posts = (
        # id breaks ties between identical timestamps so that consecutive
        # pages neither repeat nor skip posts.
        query.order_by(Post.created_at.desc(), Post.id.desc())
        .offset((page - 1) * page_size)
        .limit(page_size)
        .all()
    )
    return {
        "items": [_enrich_post_with_author(p, db) for p in posts],
        "total": total,
        "page": page,
        "page_size": page_size,
    }
@router.get("/hot")
def get_hot_posts(
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    category: Optional[str] = None,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Return published public posts ranked by popularity.

    The hotness score is ``like_count * 3 + comment_count * 2 + view_count``.
    Many posts share the same score (e.g. all brand-new posts score 0), so
    newer posts, then higher ids, break ties. This keeps pagination stable.

    Args:
        page: 1-based page number.
        page_size: Items per page (1-100).
        category: Optional category filter.

    Returns:
        ``{"items", "total", "page", "page_size"}``. Items come from
        ``_enrich_post_with_author``.
    """
    query = db.query(Post).filter(
        Post.is_public == True,  # noqa: E712 - SQL expression, not a bool test
        Post.is_draft == False,  # noqa: E712
    )
    if category:
        query = query.filter(Post.category == category)
    total = query.count()
    hotness = Post.like_count * 3 + Post.comment_count * 2 + Post.view_count
    posts = (
        query.order_by(hotness.desc(), Post.created_at.desc(), Post.id.desc())
        .offset((page - 1) * page_size)
        .limit(page_size)
        .all()
    )
    return {
        "items": [_enrich_post_with_author(p, db) for p in posts],
        "total": total,
        "page": page,
        "page_size": page_size,
    }
@router.get("/latest")
def get_latest_posts(
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    category: Optional[str] = None,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Return published public posts, newest first.

    Args:
        page: 1-based page number.
        page_size: Items per page (1-100).
        category: Optional category filter.

    Returns:
        ``{"items", "total", "page", "page_size"}``. Items come from
        ``_enrich_post_with_author``.
    """
    query = db.query(Post).filter(
        Post.is_public == True,  # noqa: E712 - SQL expression, not a bool test
        Post.is_draft == False,  # noqa: E712
    )
    if category:
        query = query.filter(Post.category == category)
    total = query.count()
    posts = (
        # id breaks ties between identical timestamps for stable paging.
        query.order_by(Post.created_at.desc(), Post.id.desc())
        .offset((page - 1) * page_size)
        .limit(page_size)
        .all()
    )
    return {
        "items": [_enrich_post_with_author(p, db) for p in posts],
        "total": total,
        "page": page,
        "page_size": page_size,
    }
@router.get("/drafts")
def get_drafts(
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Return the current user's drafts, most recently edited first.

    Args:
        page: 1-based page number.
        page_size: Items per page (1-100).

    Returns:
        ``{"items", "total", "page", "page_size"}``. Items are
        ``PostResponse`` objects built by ``_enrich_post``.
    """
    query = db.query(Post).filter(
        Post.user_id == current_user.id,
        Post.is_draft == True,  # noqa: E712 - SQL expression, not a bool test
    )
    total = query.count()
    posts = (
        # id breaks ties between identical timestamps for stable paging.
        query.order_by(Post.updated_at.desc(), Post.id.desc())
        .offset((page - 1) * page_size)
        .limit(page_size)
        .all()
    )
    return {
        "items": [_enrich_post(p, db, current_user.id) for p in posts],
        "total": total,
        "page": page,
        "page_size": page_size,
    }
def _enrich_post(post: Post, db: Session, current_user_id: Optional[int] = None) -> PostResponse:
    """Build a ``PostResponse`` for ``post`` and fill in the extra fields.

    Args:
        post: The post row to serialize.
        db: Session used for the author, like and collect lookups.
        current_user_id: The viewer's user id. When given, ``is_liked`` and
            ``is_collected`` are computed for that user. When ``None``, they
            keep whatever ``PostResponse.model_validate`` produced.

    Returns:
        The populated response model. ``author_name`` falls back to
        "未知用户" if the author row is missing.
    """
    author = db.query(User).filter(User.id == post.user_id).first()
    result = PostResponse.model_validate(post)
    result.author_name = author.username if author else "未知用户"
    # Explicit None check: a falsy-but-valid id must still get its flags.
    if current_user_id is not None:
        result.is_liked = db.query(Like).filter(
            Like.post_id == post.id, Like.user_id == current_user_id
        ).first() is not None
        result.is_collected = db.query(Collect).filter(
            Collect.post_id == post.id, Collect.user_id == current_user_id
        ).first() is not None
    return result
@router.get("", response_model=PostListResponse)
def get_posts(
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    category: Optional[str] = None,
    tag: Optional[str] = None,
    user_id: Optional[int] = None,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """List published posts visible to the current user, newest first.

    Visible means public, or written by the current user. Drafts are never
    included.

    Args:
        page: 1-based page number.
        page_size: Items per page (1-100).
        category: Optional category filter.
        tag: Optional tag filter; matches whole tags only.
        user_id: Optional author filter.

    Returns:
        A ``PostListResponse`` of ``PostResponse`` items.
    """
    query = db.query(Post).filter(
        or_(Post.is_public == True, Post.user_id == current_user.id),  # noqa: E712
        Post.is_draft == False,  # noqa: E712
    )
    if category:
        query = query.filter(Post.category == category)
    if tag:
        # Tags are stored as a JSON array string, e.g. '["python", "web"]'
        # (json.dumps(..., ensure_ascii=False) on create/update). Matching the
        # JSON-encoded token, quotes included, selects whole tags only, so "py"
        # no longer matches "python". autoescape stops LIKE wildcards (%, _)
        # inside the tag from being interpreted.
        encoded_tag = json.dumps(tag, ensure_ascii=False)
        query = query.filter(Post.tags.contains(encoded_tag, autoescape=True))
    if user_id:
        query = query.filter(Post.user_id == user_id)
    total = query.count()
    posts = (
        # id breaks ties between identical timestamps for stable paging.
        query.order_by(Post.created_at.desc(), Post.id.desc())
        .offset((page - 1) * page_size)
        .limit(page_size)
        .all()
    )
    return PostListResponse(
        items=[_enrich_post(p, db, current_user.id) for p in posts],
        total=total,
        page=page,
        page_size=page_size,
    )
@router.post("", response_model=PostResponse)
def create_post(
    data: PostCreate,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Create a post owned by the current user.

    Tags are persisted as a JSON array string (non-ASCII characters kept
    as-is). The new row is committed and refreshed before being serialized.

    Returns:
        The created post as a ``PostResponse``, from the author's viewpoint.
    """
    serialized_tags = json.dumps(data.tags, ensure_ascii=False)
    new_post = Post(
        user_id=current_user.id,
        title=data.title,
        content=data.content,
        category=data.category,
        tags=serialized_tags,
        is_public=data.is_public,
        is_draft=data.is_draft,
    )
    db.add(new_post)
    db.commit()
    db.refresh(new_post)
    return _enrich_post(new_post, db, current_user.id)
@router.get("/{post_id}", response_model=PostResponse)
def get_post(
    post_id: int,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Return one post in full and count the visit.

    Raises:
        HTTPException: 404 if no post has ``post_id``. 403 if the post is
            private or still a draft and the current user is not its author.
    """
    post = db.query(Post).filter(Post.id == post_id).first()
    if not post:
        raise HTTPException(status_code=404, detail="帖子不存在")
    # Drafts are excluded from every public listing and /drafts is scoped to
    # the author. A draft must not become readable by id just because its
    # is_public flag happens to be set.
    is_owner = post.user_id == current_user.id
    if not is_owner and (not post.is_public or post.is_draft):
        raise HTTPException(status_code=403, detail="无权访问")
    # Increment in SQL (view_count = view_count + 1) so concurrent readers
    # cannot overwrite each other's increment with a stale in-memory value.
    db.query(Post).filter(Post.id == post_id).update(
        {Post.view_count: Post.view_count + 1}, synchronize_session=False
    )
    db.commit()
    db.refresh(post)
    return _enrich_post(post, db, current_user.id)
@router.put("/{post_id}", response_model=PostResponse)
def update_post(
    post_id: int,
    data: PostUpdate,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Apply a partial update to a post.

    Admins may edit any post; other users may edit only their own posts. Any
    field of ``data`` left as ``None`` keeps its current value. ``tags`` is
    re-serialized to a JSON array string before storage.

    Raises:
        HTTPException: 404 when the post does not exist or the user may not
            edit it. The same response covers both cases.
    """
    lookup = db.query(Post).filter(Post.id == post_id)
    if not current_user.is_admin:
        # Non-admins can only reach posts they own.
        lookup = lookup.filter(Post.user_id == current_user.id)
    post = lookup.first()
    if not post:
        raise HTTPException(status_code=404, detail="帖子不存在或无权编辑")

    serialized_tags = None if data.tags is None else json.dumps(data.tags, ensure_ascii=False)
    changes = {
        "title": data.title,
        "content": data.content,
        "category": data.category,
        "tags": serialized_tags,
        "is_public": data.is_public,
        "is_draft": data.is_draft,
    }
    for column, value in changes.items():
        if value is not None:
            setattr(post, column, value)

    db.commit()
    db.refresh(post)
    return _enrich_post(post, db, current_user.id)
@router.delete("/{post_id}")
def delete_post(
    post_id: int,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Delete one of the current user's posts together with its dependent rows.

    Comment, like, collect and attachment rows pointing at the post are
    bulk-deleted first, then the post itself. Everything is persisted with a
    single commit.

    Raises:
        HTTPException: 404 when the post does not exist or belongs to
            another user.
    """
    owned_post = (
        db.query(Post)
        .filter(Post.id == post_id, Post.user_id == current_user.id)
        .first()
    )
    if owned_post is None:
        raise HTTPException(status_code=404, detail="帖子不存在或无权删除")
    # Child rows go before the parent row.
    for dependent in (Comment, Like, Collect, Attachment):
        db.query(dependent).filter(dependent.post_id == post_id).delete()
    db.delete(owned_post)
    db.commit()
    return {"message": "删除成功"}
@router.post("/{post_id}/like")
def toggle_like(
    post_id: int,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Like the post, or remove the like if the current user already liked it.

    A new like by someone other than the author also notifies the author.

    Returns:
        ``{"liked": bool, "like_count": int}`` reflecting the new state.

    Raises:
        HTTPException: 404 if no post has ``post_id``. 403 when adding a like
            to a post the user may not view (another user's private post or
            draft). Removing an existing like is always allowed, so users are
            never stuck with one.
    """
    post = db.query(Post).filter(Post.id == post_id).first()
    if not post:
        raise HTTPException(status_code=404, detail="帖子不存在")
    existing = db.query(Like).filter(
        Like.post_id == post_id, Like.user_id == current_user.id
    ).first()
    if existing:
        db.delete(existing)
        post.like_count = max(0, post.like_count - 1)
        db.commit()
        return {"liked": False, "like_count": post.like_count}

    is_author = post.user_id == current_user.id
    if not is_author and (not post.is_public or post.is_draft):
        raise HTTPException(status_code=403, detail="无权访问")
    db.add(Like(post_id=post_id, user_id=current_user.id))
    post.like_count += 1
    if not is_author:
        db.add(Notification(
            user_id=post.user_id,
            type="like",
            # The closing 」 was missing, leaving the quoted title unbalanced.
            content=f"{current_user.username} 赞了你的文章「{post.title[:30]}」",
            from_user_id=current_user.id,
            related_id=post_id,
        ))
    db.commit()
    return {"liked": True, "like_count": post.like_count}
@router.post("/{post_id}/collect")
def toggle_collect(
    post_id: int,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Collect (bookmark) the post, or un-collect it if already collected.

    Returns:
        ``{"collected": bool, "collect_count": int}`` reflecting the new state.

    Raises:
        HTTPException: 404 if no post has ``post_id``. 403 when collecting a
            post the user may not view (another user's private post or draft).
            Removing an existing collect is always allowed.
    """
    post = db.query(Post).filter(Post.id == post_id).first()
    if not post:
        raise HTTPException(status_code=404, detail="帖子不存在")
    existing = db.query(Collect).filter(
        Collect.post_id == post_id, Collect.user_id == current_user.id
    ).first()
    if existing:
        db.delete(existing)
        post.collect_count = max(0, post.collect_count - 1)
        db.commit()
        return {"collected": False, "collect_count": post.collect_count}

    # Only the author may collect their own private posts or drafts.
    if post.user_id != current_user.id and (not post.is_public or post.is_draft):
        raise HTTPException(status_code=403, detail="无权访问")
    db.add(Collect(post_id=post_id, user_id=current_user.id))
    post.collect_count += 1
    db.commit()
    return {"collected": True, "collect_count": post.collect_count}
@router.get("/{post_id}/comments", response_model=List[CommentResponse])
def get_comments(
    post_id: int,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """List a post's comments, oldest first, each with its author's name.

    Raises:
        HTTPException: 404 if no post has ``post_id``. 403 if the post is
            private or a draft and the current user is not its author.
    """
    post = db.query(Post).filter(Post.id == post_id).first()
    if not post:
        raise HTTPException(status_code=404, detail="帖子不存在")
    # Comments must not leak the content of posts the user cannot view.
    if post.user_id != current_user.id and (not post.is_public or post.is_draft):
        raise HTTPException(status_code=403, detail="无权访问")

    comments = (
        db.query(Comment)
        .filter(Comment.post_id == post_id)
        .order_by(Comment.created_at.asc())
        .all()
    )
    # Resolve all author names with one query instead of one per comment.
    author_ids = {c.user_id for c in comments}
    usernames = {}
    if author_ids:
        usernames = {
            row.id: row.username
            for row in db.query(User.id, User.username)
            .filter(User.id.in_(author_ids))
            .all()
        }
    results = []
    for comment in comments:
        item = CommentResponse.model_validate(comment)
        item.author_name = usernames.get(comment.user_id, "未知用户")
        results.append(item)
    return results
@router.post("/{post_id}/comments", response_model=CommentResponse)
def create_comment(
    post_id: int,
    data: CommentCreate,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Add a comment to a post and bump its comment counter.

    A comment by someone other than the author also notifies the author.

    Returns:
        The stored comment as a ``CommentResponse``, with ``author_name`` set.

    Raises:
        HTTPException: 404 if no post has ``post_id``. 403 if the post is
            private or a draft and the current user is not its author.
    """
    post = db.query(Post).filter(Post.id == post_id).first()
    if not post:
        raise HTTPException(status_code=404, detail="帖子不存在")
    is_author = post.user_id == current_user.id
    if not is_author and (not post.is_public or post.is_draft):
        raise HTTPException(status_code=403, detail="无权访问")

    comment = Comment(
        post_id=post_id,
        user_id=current_user.id,
        content=data.content,
    )
    db.add(comment)
    post.comment_count += 1
    if not is_author:
        db.add(Notification(
            user_id=post.user_id,
            type="comment",
            # The closing 」 was missing, leaving the quoted title unbalanced.
            content=f"{current_user.username} 评论了你的文章「{post.title[:30]}」",
            from_user_id=current_user.id,
            related_id=post_id,
        ))
    db.commit()
    db.refresh(comment)
    result = CommentResponse.model_validate(comment)
    result.author_name = current_user.username
    return result
@router.get("/{post_id}/attachments")
def get_attachments(
    post_id: int,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """List a post's attachments, oldest first.

    Returns:
        A list of dicts with ``id``, ``filename``, ``url``, ``file_size``,
        ``file_type`` and ``created_at``.

    Raises:
        HTTPException: 404 if no post has ``post_id``. 403 if the post is
            private or a draft and the current user is not its author.
    """
    post = db.query(Post).filter(Post.id == post_id).first()
    if not post:
        raise HTTPException(status_code=404, detail="帖子不存在")
    # Attachment URLs must not leak from posts the user cannot view.
    if post.user_id != current_user.id and (not post.is_public or post.is_draft):
        raise HTTPException(status_code=403, detail="无权访问")

    attachments = (
        db.query(Attachment)
        .filter(Attachment.post_id == post_id)
        .order_by(Attachment.created_at.asc())
        .all()
    )
    return [
        {
            "id": a.id,
            "filename": a.filename,
            "url": a.url,
            "file_size": a.file_size,
            "file_type": a.file_type,
            "created_at": a.created_at,
        }
        for a in attachments
    ]
@router.delete("/{post_id}/attachments/{attachment_id}")
def delete_attachment(
    post_id: int,
    attachment_id: int,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Delete an attachment row uploaded by the current user.

    The attachment must belong to ``post_id`` and have been uploaded by the
    current user. Only the database row is removed here; nothing in this
    function touches stored file content.

    Raises:
        HTTPException: 404 when no matching attachment exists or it belongs
            to another user.
    """
    ownership = (
        Attachment.id == attachment_id,
        Attachment.post_id == post_id,
        Attachment.user_id == current_user.id,
    )
    attachment = db.query(Attachment).filter(*ownership).first()
    if attachment is None:
        raise HTTPException(status_code=404, detail="附件不存在或无权删除")
    db.delete(attachment)
    db.commit()
    return {"message": "删除成功"}