初始提交:极码 GeekCode 全栈项目(FastAPI + Vue3)

This commit is contained in:
2026-04-12 10:12:18 +08:00
commit 6aecef16f6
104 changed files with 21009 additions and 0 deletions

65
backend/config.py Normal file
View File

@@ -0,0 +1,65 @@
"""应用配置"""
import os
from dotenv import load_dotenv
load_dotenv()
# 数据库配置
DATABASE_URL = os.getenv("DATABASE_URL", "mysql+pymysql://root:root@127.0.0.1:3306/biancheng?charset=utf8mb4")
# JWT配置
SECRET_KEY = os.getenv("SECRET_KEY", "your-secret-key-change-in-production")
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 7 # 7天
# 上传文件配置
UPLOAD_DIR = os.path.join(os.path.dirname(__file__), "uploads")
MAX_UPLOAD_SIZE = 10 * 1024 * 1024 # 10MB
MAX_ATTACHMENT_SIZE = 20 * 1024 * 1024 # 20MB
# 腾讯云COS配置
COS_SECRET_ID = os.getenv("COS_SECRET_ID", "")
COS_SECRET_KEY = os.getenv("COS_SECRET_KEY", "")
COS_BUCKET = os.getenv("COS_BUCKET", "") # 如 bianchengshequ-1250000000
COS_REGION = os.getenv("COS_REGION", "") # 如 ap-beijing
COS_CUSTOM_DOMAIN = os.getenv("COS_CUSTOM_DOMAIN", "") # 可选CDN自定义域名
# 大模型配置
MODEL_CONFIG = {
"multimodal": {
"provider": "google",
"model": "gemini-2.5-pro-preview-06-05",
"description": "图片/草图理解",
},
"reasoning": {
"provider": "anthropic",
"model": "claude-sonnet-4-20250514",
"description": "需求解读/架构分析",
},
"lightweight": {
"provider": "openai",
"model": "gpt-4o-mini",
"description": "分类/标签/轻量任务",
},
"knowledge_base": {
"provider": "deepseek",
"model": "deepseek-chat",
"description": "知识库文档理解/问答",
},
"embedding": {
"provider": "openai",
"model": "text-embedding-3-large",
"description": "向量化嵌入",
},
}
# API Key 配置
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY", "")
DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY", "")
# 火山方舟(豆包大模型)配置
ARK_API_KEY = os.getenv("ARK_API_KEY", "")
ARK_ENDPOINT = os.getenv("ARK_ENDPOINT", "ep-20260411180700-z6nll")
ARK_BASE_URL = "https://ark.cn-beijing.volces.com/api/v3"

34
backend/database.py Normal file
View File

@@ -0,0 +1,34 @@
"""数据库连接配置"""
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from config import DATABASE_URL
# 根据数据库类型配置引擎参数
if "sqlite" in DATABASE_URL:
engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
else:
engine = create_engine(
DATABASE_URL,
pool_size=10,
max_overflow=20,
pool_recycle=3600,
pool_pre_ping=True,
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()
def get_db():
"""获取数据库会话的依赖注入"""
db = SessionLocal()
try:
yield db
finally:
db.close()
def init_db():
"""初始化数据库,创建所有表"""
Base.metadata.create_all(bind=engine)

174
backend/main.py Normal file
View File

@@ -0,0 +1,174 @@
"""FastAPI 应用入口"""
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
import os
from database import init_db
from config import UPLOAD_DIR
import models.system_config # 确保 system_configs 表被创建
import models.category # 确保 categories 表被创建
import models.nav_category # 确保 nav_categories 表被创建
import models.nav_link # 确保 nav_links 表被创建
import models.project # 确保 projects 表被创建
import models.shared_api # 确保 shared_api 相关表被创建
import models.knowledge_base # 确保 kb 相关表被创建
import models.attachment # 确保 attachments 表被创建
from routers import auth, requirement, architecture, posts, search, ai_models, bookmarks, users, notifications, upload, admin, nav, projects, shared_api, knowledge_base, web_search, ai_format
# 确保上传目录存在
os.makedirs(UPLOAD_DIR, exist_ok=True)
app = FastAPI(
title="极码 GeekCode",
description="极码 GeekCode - 开发者社区、AI工具库、经验知识库",
version="2.0.0",
)
# CORS中间件
app.add_middleware(
CORSMiddleware,
allow_origins=["http://localhost:5173", "http://127.0.0.1:5173"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 静态文件(上传的图片等)
app.mount("/uploads", StaticFiles(directory=UPLOAD_DIR), name="uploads")
# 注册路由
app.include_router(auth.router, prefix="/api/auth", tags=["认证"])
app.include_router(requirement.router, prefix="/api/requirement", tags=["需求理解助手"])
app.include_router(architecture.router, prefix="/api/architecture", tags=["架构选型助手"])
app.include_router(posts.router, prefix="/api/posts", tags=["经验知识库"])
app.include_router(search.router, prefix="/api/search", tags=["搜索"])
app.include_router(ai_models.router)
app.include_router(ai_models.public_router)
app.include_router(bookmarks.router, prefix="/api/bookmarks", tags=["网站收藏"])
app.include_router(users.router, prefix="/api/users", tags=["用户"])
app.include_router(notifications.router, prefix="/api/notifications", tags=["消息通知"])
app.include_router(upload.router, prefix="/api/upload", tags=["文件上传"])
app.include_router(admin.router, prefix="/api/admin", tags=["后台管理"])
app.include_router(nav.router, prefix="/api/nav", tags=["导航站"])
app.include_router(projects.router, prefix="/api/projects", tags=["开源项目"])
app.include_router(shared_api.router, prefix="/api/api-hub", tags=["API Hub"])
app.include_router(knowledge_base.router, prefix="/api/kb", tags=["团队知识库"])
app.include_router(web_search.router, prefix="/api/web-search", tags=["联网搜索"])
app.include_router(ai_format.router)
@app.on_event("startup")
async def startup():
"""应用启动时初始化数据库"""
init_db()
_init_default_categories()
_migrate_user_is_approved()
_migrate_project_collect_count()
_migrate_web_search_enabled()
_migrate_web_search_count()
def _init_default_categories():
"""如果分类表为空,插入默认分类"""
from database import SessionLocal
from models.category import Category
db = SessionLocal()
try:
if db.query(Category).count() == 0:
defaults = ['前端', '后端', '部署', '踩坑', '最佳实践', '工具']
for i, name in enumerate(defaults):
db.add(Category(name=name, sort_order=i))
db.commit()
finally:
db.close()
def _migrate_user_is_approved():
"""迁移:给 users 表添加 is_approved 字段(已有用户自动设为已审核)"""
from database import SessionLocal
from sqlalchemy import text
db = SessionLocal()
try:
# 检查字段是否存在
result = db.execute(text(
"SELECT COUNT(*) FROM information_schema.COLUMNS "
"WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'users' AND COLUMN_NAME = 'is_approved'"
))
exists = result.scalar()
if not exists:
# 添加字段先默认值1让已有用户自动通过审核
db.execute(text("ALTER TABLE users ADD COLUMN is_approved TINYINT(1) NOT NULL DEFAULT 1"))
# 再把默认值改为0新用户需审核
db.execute(text("ALTER TABLE users ALTER COLUMN is_approved SET DEFAULT 0"))
db.commit()
except Exception as e:
db.rollback()
print(f"[migrate] is_approved: {e}")
finally:
db.close()
@app.get("/api/health")
async def health_check():
"""健康检查"""
return {"status": "ok", "message": "极码 GeekCode API 运行中"}
def _migrate_project_collect_count():
"""迁移:给 projects 表添加 collect_count 字段"""
from database import SessionLocal
from sqlalchemy import text
db = SessionLocal()
try:
result = db.execute(text(
"SELECT COUNT(*) FROM information_schema.COLUMNS "
"WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'projects' AND COLUMN_NAME = 'collect_count'"
))
if not result.scalar():
db.execute(text("ALTER TABLE projects ADD COLUMN collect_count INT NOT NULL DEFAULT 0"))
db.commit()
except Exception as e:
db.rollback()
print(f"[migrate] project collect_count: {e}")
finally:
db.close()
def _migrate_web_search_enabled():
"""迁移:给 ai_model_configs 表添加 web_search_enabled 字段"""
from database import SessionLocal
from sqlalchemy import text
db = SessionLocal()
try:
result = db.execute(text(
"SELECT COUNT(*) FROM information_schema.COLUMNS "
"WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'ai_model_configs' AND COLUMN_NAME = 'web_search_enabled'"
))
if not result.scalar():
db.execute(text("ALTER TABLE ai_model_configs ADD COLUMN web_search_enabled TINYINT(1) NOT NULL DEFAULT 0"))
db.commit()
except Exception as e:
db.rollback()
print(f"[migrate] web_search_enabled: {e}")
finally:
db.close()
def _migrate_web_search_count():
"""迁移:给 ai_model_configs 表添加 web_search_count 字段"""
from database import SessionLocal
from sqlalchemy import text
db = SessionLocal()
try:
result = db.execute(text(
"SELECT COUNT(*) FROM information_schema.COLUMNS "
"WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'ai_model_configs' AND COLUMN_NAME = 'web_search_count'"
))
if not result.scalar():
db.execute(text("ALTER TABLE ai_model_configs ADD COLUMN web_search_count INT NOT NULL DEFAULT 5"))
db.commit()
except Exception as e:
db.rollback()
print(f"[migrate] web_search_count: {e}")
finally:
db.close()

View File

@@ -0,0 +1,9 @@
from models.user import User
from models.conversation import Conversation, Message
from models.post import Post
from models.comment import Comment
from models.like import Like, Collect
from models.ai_model import AIModelConfig
from models.bookmark import BookmarkSite
from models.follow import Follow
from models.notification import Notification

View File

@@ -0,0 +1,25 @@
"""AI模型配置模型"""
from sqlalchemy import Column, Integer, String, Boolean, DateTime, Text
from sqlalchemy.sql import func
from database import Base
class AIModelConfig(Base):
"""AI模型配置表 - 存储各AI服务商的模型信息和API Key"""
__tablename__ = "ai_model_configs"
id = Column(Integer, primary_key=True, index=True)
provider = Column(String(50), nullable=False) # openai/anthropic/google/deepseek
provider_name = Column(String(100), default="") # 显示名称
model_id = Column(String(100), nullable=False) # 模型标识符
model_name = Column(String(100), default="") # 模型显示名称
api_key = Column(String(500), default="") # API Key
base_url = Column(String(500), default="") # 自定义API地址
task_type = Column(String(50), default="") # 任务类型: multimodal/reasoning/lightweight/embedding
is_enabled = Column(Boolean, default=True) # 是否启用
is_default = Column(Boolean, default=False) # 是否为该任务类型的默认模型
web_search_enabled = Column(Boolean, default=False) # 是否启用联网搜索(仅豆包/火山方舟支持)
web_search_count = Column(Integer, default=5) # 联网搜索结果条数1-50默认5
description = Column(Text, default="") # 描述说明
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())

View File

@@ -0,0 +1,18 @@
"""附件模型"""
from sqlalchemy import Column, Integer, String, BigInteger, DateTime, ForeignKey
from sqlalchemy.sql import func
from database import Base
class Attachment(Base):
__tablename__ = "attachments"
id = Column(Integer, primary_key=True, index=True)
post_id = Column(Integer, nullable=True, default=None, index=True) # 新建文章时为null发布后回填
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
filename = Column(String(255), nullable=False) # 原始文件名
storage_key = Column(String(500), nullable=False) # COS对象键
url = Column(String(500), nullable=False) # 完整访问URL
file_size = Column(BigInteger, nullable=False) # 文件大小(字节)
file_type = Column(String(100), nullable=False) # MIME类型
created_at = Column(DateTime(timezone=True), server_default=func.now())

View File

@@ -0,0 +1,16 @@
"""网站收藏模型"""
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
from sqlalchemy.sql import func
from database import Base
class BookmarkSite(Base):
__tablename__ = "bookmark_sites"
id = Column(Integer, primary_key=True, index=True)
user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
name = Column(String(100), nullable=False)
url = Column(String(500), nullable=False)
icon = Column(String(500), default="")
sort_order = Column(Integer, default=0)
created_at = Column(DateTime(timezone=True), server_default=func.now())

View File

@@ -0,0 +1,13 @@
"""帖子分类模型"""
from sqlalchemy import Column, Integer, String, Boolean
from database import Base
class Category(Base):
"""帖子分类表"""
__tablename__ = "categories"
id = Column(Integer, primary_key=True, index=True)
name = Column(String(50), unique=True, nullable=False)
sort_order = Column(Integer, default=0) # 排序,越小越靠前
is_active = Column(Boolean, default=True) # 是否启用

14
backend/models/comment.py Normal file
View File

@@ -0,0 +1,14 @@
"""评论模型"""
from sqlalchemy import Column, Integer, Text, DateTime, ForeignKey
from sqlalchemy.sql import func
from database import Base
class Comment(Base):
__tablename__ = "comments"
id = Column(Integer, primary_key=True, index=True)
post_id = Column(Integer, ForeignKey("posts.id"), nullable=False, index=True)
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
content = Column(Text, nullable=False)
created_at = Column(DateTime(timezone=True), server_default=func.now())

View File

@@ -0,0 +1,33 @@
"""对话模型 - 用于需求助手和架构助手"""
from sqlalchemy import Column, Integer, String, Text, DateTime, ForeignKey, Enum
from sqlalchemy.sql import func
from database import Base
import enum
class ConversationType(str, enum.Enum):
"""对话类型"""
REQUIREMENT = "requirement" # 需求理解
ARCHITECTURE = "architecture" # 架构选型
class Conversation(Base):
__tablename__ = "conversations"
id = Column(Integer, primary_key=True, index=True)
user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
title = Column(String(200), default="新对话")
type = Column(String(20), nullable=False) # requirement / architecture
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
class Message(Base):
__tablename__ = "messages"
id = Column(Integer, primary_key=True, index=True)
conversation_id = Column(Integer, ForeignKey("conversations.id"), nullable=False, index=True)
role = Column(String(20), nullable=False) # user / assistant
content = Column(Text, nullable=False)
image_urls = Column(Text, default="") # JSON数组存储图片路径
created_at = Column(DateTime(timezone=True), server_default=func.now())

17
backend/models/follow.py Normal file
View File

@@ -0,0 +1,17 @@
"""关注关系模型"""
from sqlalchemy import Column, Integer, DateTime, ForeignKey, UniqueConstraint
from sqlalchemy.sql import func
from database import Base
class Follow(Base):
__tablename__ = "follows"
id = Column(Integer, primary_key=True, index=True)
follower_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
following_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
created_at = Column(DateTime(timezone=True), server_default=func.now())
__table_args__ = (
UniqueConstraint("follower_id", "following_id", name="uq_follow"),
)

View File

@@ -0,0 +1,42 @@
"""团队知识库模型"""
from sqlalchemy import Column, Integer, String, Text, Boolean, DateTime, ForeignKey
from sqlalchemy.sql import func
from database import Base
class KbCategory(Base):
"""知识库分类"""
__tablename__ = "kb_categories"
id = Column(Integer, primary_key=True, index=True)
name = Column(String(100), nullable=False, unique=True)
icon = Column(String(500), default="")
sort_order = Column(Integer, default=0)
is_active = Column(Boolean, default=True, nullable=False)
created_at = Column(DateTime(timezone=True), server_default=func.now())
class KbItem(Base):
"""知识库条目(关联帖子)"""
__tablename__ = "kb_items"
id = Column(Integer, primary_key=True, index=True)
category_id = Column(Integer, ForeignKey("kb_categories.id"), nullable=True, index=True)
post_id = Column(Integer, ForeignKey("posts.id"), nullable=False, index=True)
title = Column(String(200), nullable=False)
summary = Column(Text, default="")
sort_order = Column(Integer, default=0)
is_active = Column(Boolean, default=True, nullable=False)
added_by = Column(Integer, ForeignKey("users.id"), nullable=True)
created_at = Column(DateTime(timezone=True), server_default=func.now())
class KbAccessLog(Base):
"""知识库访问日志"""
__tablename__ = "kb_access_logs"
id = Column(Integer, primary_key=True, index=True)
user_id = Column(Integer, ForeignKey("users.id"), nullable=True)
action = Column(String(20), default="view") # view / search / ai_chat
query = Column(Text, default="")
created_at = Column(DateTime(timezone=True), server_default=func.now())

44
backend/models/like.py Normal file
View File

@@ -0,0 +1,44 @@
"""点赞/收藏模型"""
from sqlalchemy import Column, Integer, DateTime, ForeignKey, UniqueConstraint
from sqlalchemy.sql import func
from database import Base
class Like(Base):
__tablename__ = "likes"
id = Column(Integer, primary_key=True, index=True)
post_id = Column(Integer, ForeignKey("posts.id"), nullable=False, index=True)
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
created_at = Column(DateTime(timezone=True), server_default=func.now())
__table_args__ = (
UniqueConstraint("post_id", "user_id", name="uq_like_post_user"),
)
class Collect(Base):
__tablename__ = "collects"
id = Column(Integer, primary_key=True, index=True)
post_id = Column(Integer, ForeignKey("posts.id"), nullable=False, index=True)
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
created_at = Column(DateTime(timezone=True), server_default=func.now())
__table_args__ = (
UniqueConstraint("post_id", "user_id", name="uq_collect_post_user"),
)
class ProjectCollect(Base):
"""开源项目收藏"""
__tablename__ = "project_collects"
id = Column(Integer, primary_key=True, index=True)
project_id = Column(Integer, ForeignKey("projects.id"), nullable=False, index=True)
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
created_at = Column(DateTime(timezone=True), server_default=func.now())
__table_args__ = (
UniqueConstraint("project_id", "user_id", name="uq_project_collect_user"),
)

View File

@@ -0,0 +1,15 @@
"""导航分类模型"""
from sqlalchemy import Column, Integer, String, Boolean, DateTime
from sqlalchemy.sql import func
from database import Base
class NavCategory(Base):
__tablename__ = "nav_categories"
id = Column(Integer, primary_key=True, index=True)
name = Column(String(100), nullable=False)
icon = Column(String(500), default="")
sort_order = Column(Integer, default=0)
is_active = Column(Boolean, default=True, nullable=False)
created_at = Column(DateTime(timezone=True), server_default=func.now())

View File

@@ -0,0 +1,22 @@
"""导航链接模型"""
from sqlalchemy import Column, Integer, String, Boolean, DateTime, ForeignKey
from sqlalchemy.sql import func
from database import Base
class NavLink(Base):
__tablename__ = "nav_links"
id = Column(Integer, primary_key=True, index=True)
category_id = Column(Integer, ForeignKey("nav_categories.id"), nullable=False, index=True)
name = Column(String(100), nullable=False)
url = Column(String(500), nullable=False)
icon = Column(String(500), default="")
description = Column(String(200), default="")
sort_order = Column(Integer, default=0)
is_active = Column(Boolean, default=True, nullable=False)
# 审核相关: approved=已通过, pending=待审核, rejected=已拒绝
status = Column(String(20), default="approved", nullable=False, index=True)
submitted_by = Column(Integer, ForeignKey("users.id"), nullable=True)
reject_reason = Column(String(200), default="")
created_at = Column(DateTime(timezone=True), server_default=func.now())

View File

@@ -0,0 +1,17 @@
"""消息通知模型"""
from sqlalchemy import Column, Integer, String, Text, Boolean, DateTime, ForeignKey
from sqlalchemy.sql import func
from database import Base
class Notification(Base):
__tablename__ = "notifications"
id = Column(Integer, primary_key=True, index=True)
user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
type = Column(String(20), nullable=False) # like / comment / follow / system
content = Column(Text, default="")
related_id = Column(Integer, default=None) # 关联的帖子/用户ID
from_user_id = Column(Integer, ForeignKey("users.id"), default=None) # 触发通知的用户
is_read = Column(Boolean, default=False, index=True)
created_at = Column(DateTime(timezone=True), server_default=func.now())

23
backend/models/post.py Normal file
View File

@@ -0,0 +1,23 @@
"""经验帖模型"""
from sqlalchemy import Column, Integer, String, Text, DateTime, Boolean, ForeignKey
from sqlalchemy.sql import func
from database import Base
class Post(Base):
__tablename__ = "posts"
id = Column(Integer, primary_key=True, index=True)
user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
title = Column(String(200), nullable=False)
content = Column(Text, nullable=False)
category = Column(String(50), default="") # 分类:前端/后端/部署/踩坑/最佳实践
tags = Column(Text, default="") # JSON数组存储标签
is_public = Column(Boolean, default=True)
is_draft = Column(Boolean, default=False, index=True) # 草稿状态
view_count = Column(Integer, default=0)
like_count = Column(Integer, default=0)
collect_count = Column(Integer, default=0)
comment_count = Column(Integer, default=0)
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())

24
backend/models/project.py Normal file
View File

@@ -0,0 +1,24 @@
"""开源项目模型"""
from sqlalchemy import Column, Integer, String, Text, Boolean, DateTime
from sqlalchemy.sql import func
from database import Base
class Project(Base):
__tablename__ = "projects"
id = Column(Integer, primary_key=True, index=True)
name = Column(String(200), nullable=False)
description = Column(Text, default="")
url = Column(String(500), nullable=False)
homepage = Column(String(500), default="")
icon = Column(String(500), default="")
language = Column(String(50), default="")
category = Column(String(50), default="", index=True)
stars = Column(Integer, default=0)
forks = Column(Integer, default=0)
collect_count = Column(Integer, default=0)
sort_order = Column(Integer, default=0)
is_active = Column(Boolean, default=True, nullable=False)
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())

View File

@@ -0,0 +1,59 @@
"""共享API管理模型"""
from sqlalchemy import Column, Integer, String, Text, Boolean, DateTime, ForeignKey
from sqlalchemy.sql import func
from database import Base
class SharedApiCategory(Base):
"""API分类"""
__tablename__ = "shared_api_categories"
id = Column(Integer, primary_key=True, index=True)
name = Column(String(100), nullable=False, unique=True)
icon = Column(String(500), default="")
sort_order = Column(Integer, default=0)
is_active = Column(Boolean, default=True, nullable=False)
created_at = Column(DateTime(timezone=True), server_default=func.now())
class SharedApi(Base):
"""共享API"""
__tablename__ = "shared_apis"
id = Column(Integer, primary_key=True, index=True)
category_id = Column(Integer, ForeignKey("shared_api_categories.id"), nullable=True, index=True)
name = Column(String(200), nullable=False)
description = Column(Text, default="")
base_url = Column(String(500), nullable=False)
doc_url = Column(String(500), default="")
# 认证方式: none / api_key / bearer / basic
auth_type = Column(String(20), default="none")
# 加密存储的API Key
api_key_encrypted = Column(Text, default="")
# API Key 放在哪个请求头中, 如 Authorization, X-API-Key
api_key_header = Column(String(100), default="Authorization")
# 健康检查
health_check_url = Column(String(500), default="")
last_check_time = Column(DateTime(timezone=True), nullable=True)
last_check_status = Column(String(20), default="unknown") # ok / error / unknown
# 元信息
added_by = Column(Integer, ForeignKey("users.id"), nullable=True)
tags = Column(String(500), default="") # 逗号分隔
call_count = Column(Integer, default=0)
is_active = Column(Boolean, default=True, nullable=False)
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
class SharedApiLog(Base):
"""API使用日志"""
__tablename__ = "shared_api_logs"
id = Column(Integer, primary_key=True, index=True)
api_id = Column(Integer, ForeignKey("shared_apis.id"), nullable=False, index=True)
user_id = Column(Integer, ForeignKey("users.id"), nullable=True)
action = Column(String(20), default="test") # test / health_check
request_url = Column(String(500), default="")
response_status = Column(Integer, nullable=True)
response_time_ms = Column(Integer, nullable=True)
created_at = Column(DateTime(timezone=True), server_default=func.now())

View File

@@ -0,0 +1,15 @@
"""系统配置模型 - 键值对存储"""
from sqlalchemy import Column, Integer, String, Text, DateTime
from sqlalchemy.sql import func
from database import Base
class SystemConfig(Base):
"""系统配置表 - 存储OSS等系统级配置"""
__tablename__ = "system_configs"
id = Column(Integer, primary_key=True, index=True)
key = Column(String(100), unique=True, index=True, nullable=False)
value = Column(Text, default="")
description = Column(String(200), default="")
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())

18
backend/models/user.py Normal file
View File

@@ -0,0 +1,18 @@
"""用户模型"""
from sqlalchemy import Column, Integer, String, DateTime, Boolean
from sqlalchemy.sql import func
from database import Base
class User(Base):
__tablename__ = "users"
id = Column(Integer, primary_key=True, index=True)
username = Column(String(50), unique=True, index=True, nullable=False)
email = Column(String(100), unique=True, index=True, nullable=False)
password_hash = Column(String(200), nullable=False)
avatar = Column(String(500), default="")
is_admin = Column(Boolean, default=False, nullable=False)
is_banned = Column(Boolean, default=False, nullable=False)
is_approved = Column(Boolean, default=False, nullable=False) # 新用户需管理员审核
created_at = Column(DateTime(timezone=True), server_default=func.now())

18
backend/requirements.txt Normal file
View File

@@ -0,0 +1,18 @@
fastapi==0.115.0
uvicorn[standard]==0.30.0
sqlalchemy==2.0.35
alembic==1.13.0
python-multipart==0.0.9
python-jose[cryptography]==3.3.0
passlib[bcrypt]==1.7.4
bcrypt==4.0.1
openai==1.51.0
anthropic==0.34.0
google-generativeai==0.8.0
httpx==0.27.0
python-dotenv==1.0.1
aiofiles==24.1.0
Pillow==10.4.0
markdown-it-py==3.0.0
oss2==2.19.1
pymysql==1.1.2

View File

458
backend/routers/admin.py Normal file
View File

@@ -0,0 +1,458 @@
"""后台管理路由"""
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session
from sqlalchemy import func as sa_func, distinct
from datetime import datetime, timedelta
from pydantic import BaseModel
from typing import Optional
from database import get_db
from models.user import User
from models.post import Post
from models.comment import Comment
from models.like import Like, Collect
from models.system_config import SystemConfig
from models.category import Category
from routers.auth import get_admin_user, get_current_user
router = APIRouter()
# ---------- 对象存储配置管理腾讯云COS ----------
COS_CONFIG_KEYS = [
{"key": "cos_secret_id", "description": "SecretId"},
{"key": "cos_secret_key", "description": "SecretKey"},
{"key": "cos_bucket", "description": "Bucket如 bianchengshequ-1250000000"},
{"key": "cos_region", "description": "Region如 ap-beijing"},
{"key": "cos_custom_domain", "description": "自定义域名可选CDN加速域名"},
]
def get_cos_config_from_db(db: Session) -> dict:
"""从数据库读取COS配置"""
config = {}
for item in COS_CONFIG_KEYS:
row = db.query(SystemConfig).filter(SystemConfig.key == item["key"]).first()
config[item["key"]] = row.value if row else ""
return config
class CosConfigUpdate(BaseModel):
cos_secret_id: str = ""
cos_secret_key: Optional[str] = None
cos_bucket: str = ""
cos_region: str = ""
cos_custom_domain: str = ""
@router.get("/storage/config")
async def get_storage_config(
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user),
):
"""获取对象存储配置"""
config = get_cos_config_from_db(db)
# 脱敏 SecretKey
secret = config.get("cos_secret_key", "")
if secret and len(secret) > 6:
config["cos_secret_key_masked"] = secret[:3] + "*" * (len(secret) - 6) + secret[-3:]
else:
config["cos_secret_key_masked"] = "*" * len(secret) if secret else ""
config.pop("cos_secret_key", None)
return {"config": config, "fields": COS_CONFIG_KEYS}
@router.put("/storage/config")
async def update_storage_config(
data: CosConfigUpdate,
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user),
):
"""更新对象存储配置"""
updates = data.dict(exclude_none=True)
for key, value in updates.items():
row = db.query(SystemConfig).filter(SystemConfig.key == key).first()
if row:
row.value = value
else:
desc = next((i["description"] for i in COS_CONFIG_KEYS if i["key"] == key), "")
db.add(SystemConfig(key=key, value=value, description=desc))
db.commit()
return {"message": "配置已保存"}
@router.post("/storage/test")
async def test_storage_connection(
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user),
):
"""测试COS连接"""
config = get_cos_config_from_db(db)
secret_id = config.get("cos_secret_id", "")
secret_key = config.get("cos_secret_key", "")
bucket = config.get("cos_bucket", "")
region = config.get("cos_region", "")
if not all([secret_id, secret_key, bucket, region]):
raise HTTPException(status_code=400, detail="COS配置不完整请先填写所有必填项")
try:
from qcloud_cos import CosConfig, CosS3Client
cos_config = CosConfig(Region=region, SecretId=secret_id, SecretKey=secret_key)
client = CosS3Client(cos_config)
# 尝试获取bucket信息来验证连接
client.head_bucket(Bucket=bucket)
return {"success": True, "message": "连接成功"}
except ImportError:
raise HTTPException(status_code=500, detail="服务器未安装 cos-python-sdk-v5 库,请执行 pip install cos-python-sdk-v5")
except Exception as e:
raise HTTPException(status_code=400, detail=f"连接失败: {str(e)}")
@router.get("/stats")
async def get_stats(
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user),
):
"""获取管理后台统计数据"""
today = datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0)
# 基础统计
total_users = db.query(sa_func.count(User.id)).scalar() or 0
total_posts = db.query(sa_func.count(Post.id)).scalar() or 0
total_comments = db.query(sa_func.count(Comment.id)).scalar() or 0
total_likes = db.query(sa_func.count(Like.id)).scalar() or 0
# 今日新增
today_users = db.query(sa_func.count(User.id)).filter(User.created_at >= today).scalar() or 0
today_posts = db.query(sa_func.count(Post.id)).filter(Post.created_at >= today).scalar() or 0
# 今日活跃(今日有发帖/评论/点赞行为的用户)
active_post = db.query(distinct(Post.user_id)).filter(Post.created_at >= today)
active_comment = db.query(distinct(Comment.user_id)).filter(Comment.created_at >= today)
active_like = db.query(distinct(Like.user_id)).filter(Like.created_at >= today)
active_ids = set()
for row in active_post.all():
active_ids.add(row[0])
for row in active_comment.all():
active_ids.add(row[0])
for row in active_like.all():
active_ids.add(row[0])
today_active = len(active_ids)
# 7日趋势
user_trend = []
post_trend = []
for i in range(6, -1, -1):
day_start = today - timedelta(days=i)
day_end = day_start + timedelta(days=1)
date_str = day_start.strftime("%m-%d")
u_count = db.query(sa_func.count(User.id)).filter(
User.created_at >= day_start, User.created_at < day_end
).scalar() or 0
p_count = db.query(sa_func.count(Post.id)).filter(
Post.created_at >= day_start, Post.created_at < day_end
).scalar() or 0
user_trend.append({"date": date_str, "count": u_count})
post_trend.append({"date": date_str, "count": p_count})
# 最近注册用户
recent_users = db.query(User).order_by(User.created_at.desc()).limit(5).all()
recent_users_data = [
{"id": u.id, "username": u.username, "email": u.email, "created_at": str(u.created_at)}
for u in recent_users
]
# 最近发布帖子
recent_posts = db.query(Post).order_by(Post.created_at.desc()).limit(5).all()
recent_posts_data = []
for p in recent_posts:
author = db.query(User).filter(User.id == p.user_id).first()
recent_posts_data.append({
"id": p.id, "title": p.title,
"author": author.username if author else "未知",
"created_at": str(p.created_at),
})
return {
"total_users": total_users,
"total_posts": total_posts,
"total_comments": total_comments,
"total_likes": total_likes,
"today_users": today_users,
"today_posts": today_posts,
"today_active": today_active,
"user_trend": user_trend,
"post_trend": post_trend,
"recent_users": recent_users_data,
"recent_posts": recent_posts_data,
}
@router.get("/users")
async def list_users(
search: str = "",
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100),
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user),
):
"""用户管理列表"""
query = db.query(User)
if search:
query = query.filter(User.username.contains(search))
total = query.count()
users = query.order_by(User.created_at.desc()).offset((page - 1) * page_size).limit(page_size).all()
items = []
for u in users:
post_count = db.query(sa_func.count(Post.id)).filter(Post.user_id == u.id).scalar() or 0
comment_count = db.query(sa_func.count(Comment.id)).filter(Comment.user_id == u.id).scalar() or 0
items.append({
"id": u.id,
"username": u.username,
"email": u.email,
"avatar": u.avatar or "",
"is_admin": u.is_admin,
"is_banned": getattr(u, 'is_banned', False),
"is_approved": getattr(u, 'is_approved', True),
"post_count": post_count,
"comment_count": comment_count,
"created_at": str(u.created_at),
})
return {"items": items, "total": total, "page": page, "page_size": page_size}
@router.put("/users/{user_id}/toggle-admin")
async def toggle_admin(
user_id: int,
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user),
):
"""切换用户管理员身份"""
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise HTTPException(status_code=404, detail="用户不存在")
if user.id == admin.id:
raise HTTPException(status_code=400, detail="不能修改自己的管理员状态")
user.is_admin = not user.is_admin
db.commit()
return {"message": f"{'设为' if user.is_admin else '取消'}管理员", "is_admin": user.is_admin}
@router.put("/users/{user_id}/toggle-ban")
async def toggle_ban(
user_id: int,
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user),
):
"""封禁/解封用户"""
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise HTTPException(status_code=404, detail="用户不存在")
if user.id == admin.id:
raise HTTPException(status_code=400, detail="不能封禁自己")
if user.is_admin:
raise HTTPException(status_code=400, detail="不能封禁管理员")
user.is_banned = not user.is_banned
db.commit()
return {"message": f"{'封禁' if user.is_banned else '解封'}该用户", "is_banned": user.is_banned}
@router.put("/users/{user_id}/approve")
async def approve_user(
user_id: int,
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user),
):
"""审核通过用户"""
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise HTTPException(status_code=404, detail="用户不存在")
if getattr(user, 'is_approved', False):
raise HTTPException(status_code=400, detail="该用户已通过审核")
user.is_approved = True
db.commit()
return {"message": f"已审核通过用户:{user.username}", "is_approved": True}
@router.put("/users/{user_id}/reject")
async def reject_user(
user_id: int,
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user),
):
"""拒绝/撤回用户审核"""
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise HTTPException(status_code=404, detail="用户不存在")
if user.is_admin:
raise HTTPException(status_code=400, detail="不能拒绝管理员")
user.is_approved = False
db.commit()
return {"message": f"已拒绝用户:{user.username}", "is_approved": False}
@router.get("/posts")
async def list_posts(
search: str = "",
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100),
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user),
):
"""帖子管理列表"""
query = db.query(Post)
if search:
query = query.filter(Post.title.contains(search))
total = query.count()
posts = query.order_by(Post.created_at.desc()).offset((page - 1) * page_size).limit(page_size).all()
items = []
for p in posts:
author = db.query(User).filter(User.id == p.user_id).first()
like_count = db.query(sa_func.count(Like.id)).filter(Like.post_id == p.id).scalar() or 0
comment_count = db.query(sa_func.count(Comment.id)).filter(Comment.post_id == p.id).scalar() or 0
items.append({
"id": p.id,
"title": p.title,
"author": author.username if author else "未知",
"author_id": p.user_id,
"category": p.category or "",
"is_public": p.is_public,
"like_count": like_count,
"comment_count": comment_count,
"view_count": p.view_count,
"created_at": str(p.created_at),
})
return {"items": items, "total": total, "page": page, "page_size": page_size}
@router.delete("/posts/{post_id}")
async def delete_post(
post_id: int,
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user),
):
"""管理员删除帖子"""
post = db.query(Post).filter(Post.id == post_id).first()
if not post:
raise HTTPException(status_code=404, detail="帖子不存在")
# 删除关联数据
db.query(Comment).filter(Comment.post_id == post_id).delete()
db.query(Like).filter(Like.post_id == post_id).delete()
db.query(Collect).filter(Collect.post_id == post_id).delete()
db.delete(post)
db.commit()
return {"message": "删除成功"}
# ---------- 分类管理 ----------
@router.get("/categories")
async def list_categories(
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user),
):
"""获取所有分类(含禁用的)"""
cats = db.query(Category).order_by(Category.sort_order, Category.id).all()
return [{"id": c.id, "name": c.name, "sort_order": c.sort_order, "is_active": c.is_active} for c in cats]
class CategoryCreate(BaseModel):
name: str
class CategoryUpdate(BaseModel):
name: Optional[str] = None
sort_order: Optional[int] = None
is_active: Optional[bool] = None
@router.post("/categories")
async def create_category(
data: CategoryCreate,
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user),
):
"""新增分类"""
existing = db.query(Category).filter(Category.name == data.name).first()
if existing:
raise HTTPException(status_code=400, detail="分类名称已存在")
max_order = db.query(sa_func.max(Category.sort_order)).scalar() or 0
cat = Category(name=data.name, sort_order=max_order + 1)
db.add(cat)
db.commit()
db.refresh(cat)
return {"id": cat.id, "name": cat.name, "sort_order": cat.sort_order, "is_active": cat.is_active}
@router.put("/categories/{cat_id}")
async def update_category(
cat_id: int,
data: CategoryUpdate,
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user),
):
"""修改分类"""
cat = db.query(Category).filter(Category.id == cat_id).first()
if not cat:
raise HTTPException(status_code=404, detail="分类不存在")
if data.name is not None:
dup = db.query(Category).filter(Category.name == data.name, Category.id != cat_id).first()
if dup:
raise HTTPException(status_code=400, detail="分类名称已存在")
cat.name = data.name
if data.sort_order is not None:
cat.sort_order = data.sort_order
if data.is_active is not None:
cat.is_active = data.is_active
db.commit()
return {"id": cat.id, "name": cat.name, "sort_order": cat.sort_order, "is_active": cat.is_active}
@router.delete("/categories/{cat_id}")
async def delete_category(
cat_id: int,
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user),
):
"""删除分类"""
cat = db.query(Category).filter(Category.id == cat_id).first()
if not cat:
raise HTTPException(status_code=404, detail="分类不存在")
db.delete(cat)
db.commit()
return {"message": "删除成功"}
@router.put("/categories/reorder")
async def reorder_categories(
items: list,
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user),
):
"""批量更新分类排序"""
for item in items:
cat = db.query(Category).filter(Category.id == item["id"]).first()
if cat:
cat.sort_order = item["sort_order"]
db.commit()
return {"message": "排序已更新"}
# ---------- 公开分类API无需管理员权限 ----------
@router.get("/public/categories")
async def get_public_categories(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""获取启用的分类列表(前台使用)"""
cats = db.query(Category).filter(Category.is_active == True).order_by(Category.sort_order, Category.id).all()
return [c.name for c in cats]

View File

@@ -0,0 +1,223 @@
"""AI智能排版路由 - 文本排版 + 自动配图"""
import json
import uuid
import httpx
from datetime import datetime
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from sqlalchemy.orm import Session
from database import get_db
from models.user import User
from models.ai_model import AIModelConfig
from routers.auth import get_current_user
from services.ai_service import ai_service
router = APIRouter(prefix="/api/ai", tags=["AI智能排版"])
FORMAT_SYSTEM_PROMPT = """你是一位专业的文章排版编辑。你的任务是将用户提供的文章内容重新排版为结构清晰、可读性强的 Markdown 格式。
【核心原则】绝对不能修改、删除、改写或替换原文的任何文字内容。原文的每一个字、每一句话都必须原样保留。你可以在合适的位置补充过渡语、小结或说明文字来增强可读性,但必须与原文明确区分,且不能改动原文已有的文字。
排版规则:
1. 分析文章结构,在合适位置添加标题层级(## 和 ###),标题文字从原文中提取
2. 将原文中的要点整理为列表(有序或无序),但列表内容必须是原文原句
3. 重要观点用引用块 > 包裹,引用内容必须是原文原句
4. 关键词和重要内容用 **加粗** 标记
5. 保留原文中所有图片链接(![...](...)格式),不要修改或删除
6. 保留原文中所有URL链接转为 [链接文字](url) 格式
7. 适当添加分隔线 --- 划分章节
8. 长段落拆分为短段落,提高可读性,但不能改变段落中的文字
9. 如果内容中有流程或步骤,用有序列表清晰展示
10. 不要添加原文中没有的文字、解释、总结或过渡语
同时,你需要分析文章内容,为文章建议 1-2 张配图。对于每张配图,在合适的位置插入占位符,格式为:
[AI_IMAGE: 图片描述prompt用英文写描述要生成的图片内容风格简洁专业]
注意:
- 占位符要插在文章逻辑合适的位置(如章节开头、流程说明旁边)
- prompt 用英文描述风格flat illustration, modern, professional, tech style
- 不要超过 2 个图片占位符
- 如果文章内容不适合配图(如纯代码、纯链接列表),可以不加占位符"""
class FormatRequest(BaseModel):
model_config = {"protected_namespaces": ()}
content: str
generate_images: bool = False # 是否生成配图(默认关闭)
model_config_id: Optional[int] = None # 指定排版用的文本模型
class FormatResponse(BaseModel):
formatted_content: str
images_generated: int = 0
def _get_image_model(db: Session):
"""从数据库查找已配置的图像生成模型Seedream endpoint"""
# 查找 task_type 包含 image 或 model_name 包含 seedream 的模型
model = db.query(AIModelConfig).filter(
AIModelConfig.is_enabled == True,
AIModelConfig.task_type == "image",
).first()
if model:
return {
"api_key": model.api_key,
"base_url": model.base_url or "https://ark.cn-beijing.volces.com/api/v3",
"model_id": model.model_id,
}
return None
async def _generate_image(api_key: str, base_url: str, model_id: str, prompt: str) -> Optional[bytes]:
"""调用火山方舟 Seedream 生成图片,返回图片二进制数据"""
url = f"{base_url.rstrip('/')}/images/generations"
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
payload = {
"model": model_id,
"prompt": prompt,
"response_format": "url",
"size": "1920x1080", # 满足 Seedream 5.0 最低像素要求
}
try:
async with httpx.AsyncClient(timeout=60) as client:
resp = await client.post(url, json=payload, headers=headers)
if resp.status_code != 200:
print(f"图像生成失败: {resp.status_code} - {resp.text}")
return None
data = resp.json()
image_url = data.get("data", [{}])[0].get("url", "")
if not image_url:
return None
# 下载图片(因为 Seedream 返回的是临时链接)
img_resp = await client.get(image_url)
if img_resp.status_code == 200:
return img_resp.content
return None
except Exception as e:
print(f"图像生成异常: {e}")
return None
def _upload_to_cos(db: Session, image_data: bytes) -> Optional[str]:
"""将图片上传到 COS返回永久 URL"""
from models.system_config import SystemConfig
keys = ["cos_secret_id", "cos_secret_key", "cos_bucket", "cos_region", "cos_custom_domain"]
config = {}
for k in keys:
row = db.query(SystemConfig).filter(SystemConfig.key == k).first()
config[k] = row.value if row else ""
secret_id = config.get("cos_secret_id", "")
secret_key = config.get("cos_secret_key", "")
bucket = config.get("cos_bucket", "")
region = config.get("cos_region", "")
if not all([secret_id, secret_key, bucket, region]):
return None
try:
from qcloud_cos import CosConfig, CosS3Client
cos_config = CosConfig(Region=region, SecretId=secret_id, SecretKey=secret_key)
client = CosS3Client(cos_config)
date_prefix = datetime.now().strftime("%Y/%m")
filename = f"{uuid.uuid4().hex}.png"
object_key = f"images/{date_prefix}/{filename}"
client.put_object(
Bucket=bucket,
Body=image_data,
Key=object_key,
ContentType="image/png",
)
custom_domain = config.get("cos_custom_domain", "")
if custom_domain:
return f"https://{custom_domain}/{object_key}"
return f"https://{bucket}.cos.{region}.myqcloud.com/{object_key}"
except Exception as e:
print(f"COS上传失败: {e}")
return None
@router.post("/format", response_model=FormatResponse)
async def format_article(
req: FormatRequest,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""AI智能排版格式化文章 + 自动生成配图"""
if not req.content.strip():
raise HTTPException(status_code=400, detail="文章内容不能为空")
# 第1步AI 排版文本
messages = [{"role": "user", "content": req.content}]
try:
formatted = await ai_service.chat(
task_type="reasoning",
messages=messages,
system_prompt=FORMAT_SYSTEM_PROMPT,
stream=False,
model_config_id=req.model_config_id,
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"AI排版失败: {str(e)}")
if not isinstance(formatted, str):
raise HTTPException(status_code=500, detail="AI排版返回格式异常")
# 过滤掉思考过程DeepSeek Reasoner 会输出 <think>...</think>
import re as _re
formatted = _re.sub(r'<think>[\s\S]*?</think>\s*', '', formatted).strip()
# 也过滤 <details> 格式的思考过程
formatted = _re.sub(r'<details>[\s\S]*?</details>\s*', '', formatted).strip()
images_generated = 0
# 第2步生成配图如果启用
if req.generate_images:
import re
placeholders = re.findall(r'\[AI_IMAGE:\s*(.+?)\]', formatted)
if placeholders:
image_model = _get_image_model(db)
if image_model:
for prompt in placeholders:
image_data = await _generate_image(
image_model["api_key"],
image_model["base_url"],
image_model["model_id"],
prompt.strip(),
)
if image_data:
# 上传到 COS
cos_url = _upload_to_cos(db, image_data)
if cos_url:
formatted = formatted.replace(
f"[AI_IMAGE: {prompt}]",
f"![{prompt.strip()[:50]}]({cos_url})",
1,
)
images_generated += 1
continue
# 生成或上传失败,移除占位符
formatted = formatted.replace(f"[AI_IMAGE: {prompt}]", "", 1)
else:
# 没有配置图像模型,清理所有占位符
formatted = re.sub(r'\[AI_IMAGE:\s*.+?\]', '', formatted)
# 清理可能残留的占位符
import re
formatted = re.sub(r'\[AI_IMAGE:\s*.+?\]', '', formatted)
# 清理多余空行
formatted = re.sub(r'\n{3,}', '\n\n', formatted).strip()
return FormatResponse(
formatted_content=formatted,
images_generated=images_generated,
)

View File

@@ -0,0 +1,285 @@
"""AI模型管理路由"""
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from typing import List
from database import get_db
from models.ai_model import AIModelConfig
from models.user import User
from schemas.ai_model import AIModelCreate, AIModelUpdate, AIModelResponse, ProviderInfo
from routers.auth import get_admin_user, get_current_user
router = APIRouter(prefix="/api/admin/models", tags=["AI模型管理"])
# 公开路由(登录用户可用,用于前台 AI 工具获取可选模型)
public_router = APIRouter(prefix="/api/models", tags=["AI模型公开"])
# 预置的服务商和模型信息
PROVIDER_PRESETS = [
{
"provider": "deepseek",
"name": "DeepSeek",
"default_base_url": "https://api.deepseek.com",
"models": [
{"model_id": "deepseek-chat", "name": "DeepSeek-V3.2", "task_types": ["lightweight", "knowledge_base"], "description": "DeepSeek-V3.2 非思考模式,性价比极高"},
{"model_id": "deepseek-reasoner", "name": "DeepSeek-V3.2 思考", "task_types": ["reasoning", "knowledge_base"], "description": "DeepSeek-V3.2 思考模式,带推理链输出"},
]
},
{
"provider": "openai",
"name": "OpenAI",
"default_base_url": "https://api.openai.com/v1",
"models": [
{"model_id": "gpt-4o", "name": "GPT-4o", "task_types": ["multimodal", "reasoning"], "description": "多模态旗舰模型"},
{"model_id": "gpt-4o-mini", "name": "GPT-4o Mini", "task_types": ["lightweight"], "description": "轻量高效模型"},
{"model_id": "o3-mini", "name": "o3-mini", "task_types": ["reasoning"], "description": "推理增强模型"},
{"model_id": "text-embedding-3-large", "name": "Embedding Large", "task_types": ["embedding"], "description": "高维度文本嵌入模型"},
]
},
{
"provider": "anthropic",
"name": "Anthropic",
"default_base_url": "https://api.anthropic.com",
"models": [
{"model_id": "claude-sonnet-4-20250514", "name": "Claude Sonnet 4", "task_types": ["reasoning"], "description": "Claude最新推理模型"},
{"model_id": "claude-3-5-haiku-20241022", "name": "Claude 3.5 Haiku", "task_types": ["lightweight"], "description": "快速轻量模型"},
]
},
{
"provider": "google",
"name": "Google Gemini",
"default_base_url": "https://generativelanguage.googleapis.com/v1beta/openai",
"models": [
{"model_id": "gemini-2.5-pro-preview-06-05", "name": "Gemini 2.5 Pro", "task_types": ["multimodal", "reasoning"], "description": "多模态能力最强"},
{"model_id": "gemini-2.0-flash", "name": "Gemini 2.0 Flash", "task_types": ["lightweight", "multimodal"], "description": "快速多模态模型"},
]
},
{
"provider": "ark",
"name": "火山方舟(豆包)",
"default_base_url": "https://ark.cn-beijing.volces.com/api/v3",
"models": [
{"model_id": "ep-20260411180700-z6nll", "name": "Doubao-Seed-2.0-pro", "task_types": ["reasoning", "lightweight", "knowledge_base"], "description": "豆包旗舰模型,支持联网搜索"},
{"model_id": "doubao-seedream-5-0-260128", "name": "Seedream 5.0 (图像生成)", "task_types": ["image"], "description": "豆包图像生成模型,支持文生图"},
]
},
]
TASK_TYPE_LABELS = {
"multimodal": "多模态(图片/草图理解)",
"reasoning": "推理分析(需求解读/架构分析)",
"lightweight": "轻量任务(分类/标签)",
"knowledge_base": "知识库分析(文档理解/问答)",
"embedding": "向量嵌入",
"image": "图像生成AI配图/文生图)",
}
def _mask_api_key(key: str) -> str:
"""API Key脱敏"""
if not key or len(key) < 8:
return "****" if key else ""
return key[:4] + "*" * (len(key) - 8) + key[-4:]
def _to_response(model: AIModelConfig) -> dict:
"""转换为响应格式"""
return {
"id": model.id,
"provider": model.provider,
"provider_name": model.provider_name,
"model_id": model.model_id,
"model_name": model.model_name,
"api_key_masked": _mask_api_key(model.api_key),
"base_url": model.base_url,
"task_type": model.task_type,
"is_enabled": model.is_enabled,
"is_default": model.is_default,
"web_search_enabled": model.web_search_enabled,
"description": model.description,
"created_at": model.created_at,
"updated_at": model.updated_at,
}
@router.get("/presets", response_model=List[ProviderInfo])
async def get_provider_presets():
"""获取预置的服务商和模型列表"""
return PROVIDER_PRESETS
@router.get("/task-types")
async def get_task_types():
"""获取任务类型列表"""
return TASK_TYPE_LABELS
@router.get("", response_model=List[AIModelResponse])
async def list_models(provider: str = None, task_type: str = None, db: Session = Depends(get_db)):
"""获取所有已配置的模型"""
query = db.query(AIModelConfig)
if provider:
query = query.filter(AIModelConfig.provider == provider)
if task_type:
query = query.filter(AIModelConfig.task_type == task_type)
models = query.order_by(AIModelConfig.provider, AIModelConfig.created_at).all()
return [_to_response(m) for m in models]
@router.post("", response_model=AIModelResponse)
async def create_model(data: AIModelCreate, db: Session = Depends(get_db), admin: User = Depends(get_admin_user)):
"""添加模型配置"""
model = AIModelConfig(
provider=data.provider,
provider_name=data.provider_name,
model_id=data.model_id,
model_name=data.model_name,
api_key=data.api_key,
base_url=data.base_url,
task_type=data.task_type,
is_enabled=data.is_enabled,
is_default=data.is_default,
web_search_enabled=data.web_search_enabled,
description=data.description,
)
# 如果设为默认,取消同任务类型的其他默认
if data.is_default and data.task_type:
db.query(AIModelConfig).filter(
AIModelConfig.task_type == data.task_type,
AIModelConfig.is_default == True
).update({"is_default": False})
db.add(model)
db.commit()
db.refresh(model)
return _to_response(model)
@router.put("/{model_id}", response_model=AIModelResponse)
async def update_model(model_id: int, data: AIModelUpdate, db: Session = Depends(get_db), admin: User = Depends(get_admin_user)):
"""更新模型配置"""
model = db.query(AIModelConfig).filter(AIModelConfig.id == model_id).first()
if not model:
raise HTTPException(status_code=404, detail="模型配置不存在")
update_data = data.dict(exclude_unset=True)
# 如果API Key为空字符串表示不修改
if "api_key" in update_data and update_data["api_key"] == "":
del update_data["api_key"]
# 如果设为默认,取消同任务类型的其他默认
if update_data.get("is_default") and (update_data.get("task_type") or model.task_type):
task = update_data.get("task_type", model.task_type)
db.query(AIModelConfig).filter(
AIModelConfig.task_type == task,
AIModelConfig.is_default == True,
AIModelConfig.id != model_id
).update({"is_default": False})
for key, value in update_data.items():
setattr(model, key, value)
db.commit()
db.refresh(model)
return _to_response(model)
@router.delete("/{model_id}")
async def delete_model(model_id: int, db: Session = Depends(get_db), admin: User = Depends(get_admin_user)):
"""删除模型配置"""
model = db.query(AIModelConfig).filter(AIModelConfig.id == model_id).first()
if not model:
raise HTTPException(status_code=404, detail="模型配置不存在")
db.delete(model)
db.commit()
return {"message": "删除成功"}
@router.post("/init-defaults")
async def init_default_models(db: Session = Depends(get_db), admin: User = Depends(get_admin_user)):
"""初始化默认模型配置(仅当数据库为空时)"""
count = db.query(AIModelConfig).count()
if count > 0:
return {"message": f"已有 {count} 条配置,跳过初始化", "count": count}
defaults = [
AIModelConfig(provider="deepseek", provider_name="DeepSeek", model_id="deepseek-chat",
model_name="DeepSeek-V3", task_type="reasoning", is_default=True, is_enabled=True,
base_url="https://api.deepseek.com/v1", description="DeepSeek最新对话模型性价比极高"),
AIModelConfig(provider="deepseek", provider_name="DeepSeek", model_id="deepseek-reasoner",
model_name="DeepSeek-R1", task_type="", is_enabled=True,
base_url="https://api.deepseek.com/v1", description="深度推理模型,适合复杂逻辑分析"),
AIModelConfig(provider="openai", provider_name="OpenAI", model_id="gpt-4o-mini",
model_name="GPT-4o Mini", task_type="lightweight", is_default=True, is_enabled=True,
description="轻量高效模型"),
AIModelConfig(provider="google", provider_name="Google Gemini", model_id="gemini-2.5-pro-preview-06-05",
model_name="Gemini 2.5 Pro", task_type="multimodal", is_default=True, is_enabled=True,
base_url="https://generativelanguage.googleapis.com/v1beta/openai",
description="多模态能力最强"),
]
db.add_all(defaults)
db.commit()
return {"message": f"已初始化 {len(defaults)} 条默认配置", "count": len(defaults)}
@router.post("/{model_id}/test")
async def test_model_connection(model_id: int, db: Session = Depends(get_db), admin: User = Depends(get_admin_user)):
"""测试模型连接是否正常"""
model = db.query(AIModelConfig).filter(AIModelConfig.id == model_id).first()
if not model:
raise HTTPException(status_code=404, detail="模型配置不存在")
if not model.api_key:
return {"success": False, "message": "未配置 API Key"}
try:
import httpx
headers = {"Authorization": f"Bearer {model.api_key}", "Content-Type": "application/json"}
base_url = model.base_url or f"https://api.{model.provider}.com/v1"
if model.provider == "anthropic":
headers = {"x-api-key": model.api_key, "anthropic-version": "2023-06-01", "Content-Type": "application/json"}
url = f"{base_url}/v1/messages"
payload = {"model": model.model_id, "max_tokens": 10, "messages": [{"role": "user", "content": "hi"}]}
else:
url = f"{base_url}/chat/completions"
payload = {"model": model.model_id, "max_tokens": 10, "messages": [{"role": "user", "content": "hi"}]}
async with httpx.AsyncClient(timeout=15.0) as client:
resp = await client.post(url, json=payload, headers=headers)
if resp.status_code == 200:
return {"success": True, "message": "连接成功"}
else:
return {"success": False, "message": f"HTTP {resp.status_code}: {resp.text[:200]}"}
except Exception as e:
return {"success": False, "message": f"连接失败: {str(e)}"}
# ===== 公开接口 =====
@public_router.get("/available")
async def get_available_models(
task_type: str = "",
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""获取可用模型列表(登录用户可调用,不返回 API Key"""
query = db.query(AIModelConfig).filter(
AIModelConfig.is_enabled == True,
AIModelConfig.api_key != "",
)
if task_type:
query = query.filter(AIModelConfig.task_type == task_type)
models = query.order_by(AIModelConfig.is_default.desc(), AIModelConfig.created_at).all()
return [
{
"id": m.id,
"provider": m.provider,
"provider_name": m.provider_name,
"model_id": m.model_id,
"model_name": m.model_name,
"task_type": m.task_type,
"is_default": m.is_default,
"web_search_enabled": m.web_search_enabled,
"web_search_count": m.web_search_count or 5,
"description": m.description,
}
for m in models
]

View File

@@ -0,0 +1,296 @@
"""架构选型助手路由"""
import json
from typing import List
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from database import get_db
from models.user import User
from models.conversation import Conversation, Message
from schemas.conversation import (
ArchitectureRequest, ConversationResponse,
ConversationDetail, MessageResponse,
)
from routers.auth import get_current_user
from services.ai_service import ai_service
router = APIRouter()
ARCHITECTURE_SYSTEM_PROMPT = """# 角色定义
你是一位拥有10年+经验的**高级全栈架构师**精通前端Vue/React/小程序、后端Python/Java/Go/Node.js、数据库MySQL/PostgreSQL/MongoDB/Redis、云服务与DevOps。你做过大量从0到1的项目对技术选型的利弊、不同规模系统的架构模式了如指掌。
你的工作:接收用户提供的**已确认的功能需求**(可能来自需求助手的输出),给出完整的、可直接落地开发的技术方案。
> ⚠️ 本助手专注于**技术选型与架构设计**。如果用户发来的是原始甲方需求,建议先到「需求理解助手」进行需求分析。
# 核心理念
- **没有最好的技术,只有最合适的技术**:选型必须匹配项目规模、团队能力和预算
- **方案要能落地写代码**:不出纯理论的架构图,给的方案要具体到程序员能直接开干
- **过度设计是大忌**:小项目用微服务是灾难,要敢于推荐简单方案
# 分析框架
## 第一步:项目画像评估
- 项目规模:小型(个人/小团队)/ 中型(创业公司)/ 大型(企业级)
- 预期用户量和并发量
- 团队技术栈偏好(如果用户有提及)
- 预算和时间约束
## 第二步:技术选型(带对比和理由)
针对每一层给出推荐方案和备选方案:
- 前端框架 + UI组件库
- 后端语言 + Web框架
- 数据库(主库 + 缓存)
- 文件存储方案
- 部署方案
- 第三方服务(如果需要)
## 第三步:系统架构设计
- 整体架构图Mermaid语法
- 核心数据模型ER关系、表结构
- 关键接口设计RESTful API清单
- 目录结构规划
## 第四步:技术难点与避坑指南
- 基于实战经验,针对该项目的具体技术难点给出解决方案
- 常见踩坑点和规避策略
- 安全注意事项XSS、CSRF、SQL注入、越权等
## 第五步:开发路线图
- MVP版本应包含哪些功能
- 迭代计划建议
- 工期评估(按模块拆分前后端工时)
# 输出规范
严格使用以下 Markdown 结构输出:
---
## 🎯 项目画像
| 维度 | 评估 |
|------|------|
| 项目规模 | xxx |
| 预期用户量 | xxx |
| 推荐架构模式 | 单体/前后端分离/微服务 |
## 🏗️ 技术选型
| 层级 | 推荐方案 | 备选方案 | 选型理由 |
|------|---------|---------|---------|
| 前端框架 | xxx | xxx | xxx |
| UI组件库 | xxx | xxx | xxx |
| 后端框架 | xxx | xxx | xxx |
| 数据库 | xxx | xxx | xxx |
| 缓存 | xxx | xxx | xxx |
| 部署 | xxx | xxx | xxx |
## 📐 系统架构图
```mermaid
graph TB
A[前端] --> B[API网关]
B --> C[后端服务]
C --> D[数据库]
```
## 🗄️ 核心数据模型
```sql
-- 表名: xxx
-- 说明: xxx
CREATE TABLE xxx (
id BIGINT PRIMARY KEY AUTO_INCREMENT,
xxx VARCHAR(255) NOT NULL COMMENT 'xxx',
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
-- 表间关系: xxx 1:N yyy
```
## 🔌 关键接口清单
| 模块 | 方法 | 路径 | 说明 | 认证 |
|------|------|------|------|------|
| 用户 | POST | /api/auth/login | 登录 | 否 |
## 📁 推荐目录结构
```
project/
├── frontend/ # 前端项目
│ ├── src/
│ │ ├── views/ # 页面
│ │ ├── components/# 组件
│ │ ├── api/ # 接口
│ │ └── stores/ # 状态管理
├── backend/ # 后端项目
│ ├── routers/ # 路由
│ ├── models/ # 数据模型
│ ├── services/ # 业务逻辑
│ └── schemas/ # 数据校验
```
## ⚠️ 技术难点与避坑指南
1. **【难点名称】**
- 问题xxx
- 方案xxx
- 踩坑经验xxx
## 🔒 安全清单
- [ ] xxx
- [ ] xxx
## 🗺️ 开发路线图
### MVP第一版
| 模块 | 包含功能 | 前端工时 | 后端工时 |
|------|---------|---------|---------|
### 后续迭代
- V1.1: xxx
- V1.2: xxx
---
# 交互原则
1. **选型必须带理由**:不说"推荐用Vue",要说"推荐Vue 3因为xxx如果团队熟悉React也可以用"
2. **方案要分档**:针对不同预算/规模给出不同方案(如"预算充足用云服务省钱可以用VPS"
3. **代码要能跑**给出的SQL、目录结构、接口设计都要是可以直接使用的
4. **架构图用Mermaid**:使用 ```mermaid 代码块,只用基础语法,不加样式
5. **敢于说"不需要"**如果项目不需要Redis/微服务/消息队列,要直说,不为了显得高级而过度设计
6. **持续深化**:用户追问某个模块时,在已有方案基础上深入展开,保持一致性"""
@router.get("/conversations", response_model=List[ConversationResponse])
def get_conversations(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""获取架构对话列表"""
conversations = (
db.query(Conversation)
.filter(Conversation.user_id == current_user.id, Conversation.type == "architecture")
.order_by(Conversation.updated_at.desc())
.all()
)
return [ConversationResponse.model_validate(c) for c in conversations]
@router.get("/conversations/{conversation_id}", response_model=ConversationDetail)
def get_conversation_detail(
conversation_id: int,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""获取对话详情"""
conv = db.query(Conversation).filter(
Conversation.id == conversation_id,
Conversation.user_id == current_user.id,
).first()
if not conv:
raise HTTPException(status_code=404, detail="对话不存在")
messages = (
db.query(Message)
.filter(Message.conversation_id == conversation_id)
.order_by(Message.created_at.asc())
.all()
)
result = ConversationDetail.model_validate(conv)
result.messages = [MessageResponse.model_validate(m) for m in messages]
return result
@router.post("/recommend")
async def recommend_architecture(
request: ArchitectureRequest,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""架构推荐 - 流式输出"""
# 创建或获取对话
if request.conversation_id:
conv = db.query(Conversation).filter(
Conversation.id == request.conversation_id,
Conversation.user_id == current_user.id,
).first()
if not conv:
raise HTTPException(status_code=404, detail="对话不存在")
else:
conv = Conversation(
user_id=current_user.id,
title=request.content[:50] if request.content else "新架构咨询",
type="architecture",
)
db.add(conv)
db.commit()
db.refresh(conv)
# 保存用户消息
user_msg = Message(
conversation_id=conv.id,
role="user",
content=request.content,
)
db.add(user_msg)
db.commit()
# 构建历史消息
history_msgs = (
db.query(Message)
.filter(Message.conversation_id == conv.id)
.order_by(Message.created_at.asc())
.all()
)
messages = [{"role": msg.role, "content": msg.content} for msg in history_msgs]
# 流式调用AI
async def generate():
full_response = ""
try:
result = await ai_service.chat(
task_type="reasoning",
messages=messages,
system_prompt=ARCHITECTURE_SYSTEM_PROMPT,
stream=True,
model_config_id=request.model_config_id,
)
if isinstance(result, str):
full_response = result
yield f"data: {json.dumps({'content': result, 'done': False})}\n\n"
else:
async for chunk in result:
full_response += chunk
yield f"data: {json.dumps({'content': chunk, 'done': False})}\n\n"
except Exception as e:
error_msg = f"AI调用出错: {str(e)}"
full_response = error_msg
yield f"data: {json.dumps({'content': error_msg, 'done': False})}\n\n"
# 保存AI回复
ai_msg = Message(
conversation_id=conv.id,
role="assistant",
content=full_response,
)
db.add(ai_msg)
db.commit()
yield f"data: {json.dumps({'content': '', 'done': True, 'conversation_id': conv.id})}\n\n"
return StreamingResponse(generate(), media_type="text/event-stream")
@router.delete("/conversations/{conversation_id}")
def delete_conversation(
conversation_id: int,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""删除对话"""
conv = db.query(Conversation).filter(
Conversation.id == conversation_id,
Conversation.user_id == current_user.id,
).first()
if not conv:
raise HTTPException(status_code=404, detail="对话不存在")
db.query(Message).filter(Message.conversation_id == conversation_id).delete()
db.delete(conv)
db.commit()
return {"message": "删除成功"}

136
backend/routers/auth.py Normal file
View File

@@ -0,0 +1,136 @@
"""认证路由"""
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from passlib.context import CryptContext
from jose import JWTError, jwt
from datetime import datetime, timedelta
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from database import get_db
from models.user import User
from schemas.user import UserRegister, UserLogin, UserResponse, TokenResponse, UserUpdate
from config import SECRET_KEY, ALGORITHM, ACCESS_TOKEN_EXPIRE_MINUTES
router = APIRouter()
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
security = HTTPBearer()
def create_access_token(data: dict) -> str:
"""创建JWT Token"""
to_encode = data.copy()
expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
to_encode.update({"exp": expire})
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def get_current_user(
credentials: HTTPAuthorizationCredentials = Depends(security),
db: Session = Depends(get_db),
) -> User:
"""从Token获取当前用户"""
token = credentials.credentials
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
user_id = payload.get("sub")
if user_id is None:
raise HTTPException(status_code=401, detail="无效的认证凭据")
user_id = int(user_id)
except JWTError:
raise HTTPException(status_code=401, detail="无效的认证凭据")
user = db.query(User).filter(User.id == user_id).first()
if user is None:
raise HTTPException(status_code=401, detail="用户不存在")
return user
def get_admin_user(current_user: User = Depends(get_current_user)) -> User:
"""要求当前用户是管理员"""
if not current_user.is_admin:
raise HTTPException(status_code=403, detail="需要管理员权限")
return current_user
@router.post("/register")
def register(data: UserRegister, db: Session = Depends(get_db)):
"""用户注册(需管理员审核后才可使用)"""
# 检查用户名是否已存在
if db.query(User).filter(User.username == data.username).first():
raise HTTPException(status_code=400, detail="用户名已存在")
if db.query(User).filter(User.email == data.email).first():
raise HTTPException(status_code=400, detail="邮箱已被注册")
# 创建用户is_approved 默认 False等待审核
user = User(
username=data.username,
email=data.email,
password_hash=pwd_context.hash(data.password),
)
db.add(user)
db.commit()
db.refresh(user)
return {"message": "注册成功,请等待管理员审核通过后即可登录使用"}
@router.post("/login", response_model=TokenResponse)
def login(data: UserLogin, db: Session = Depends(get_db)):
"""用户登录"""
user = db.query(User).filter(User.username == data.username).first()
if not user or not pwd_context.verify(data.password, user.password_hash):
raise HTTPException(status_code=401, detail="用户名或密码错误")
if getattr(user, 'is_banned', False):
raise HTTPException(status_code=403, detail="账号已被封禁,请联系管理员")
if not getattr(user, 'is_approved', False):
raise HTTPException(status_code=403, detail="账号尚未通过审核,请耐心等待管理员审核")
token = create_access_token({"sub": str(user.id)})
return TokenResponse(
access_token=token,
user=UserResponse.model_validate(user),
)
@router.get("/me", response_model=UserResponse)
def get_me(current_user: User = Depends(get_current_user)):
"""获取当前用户信息"""
return UserResponse.model_validate(current_user)
@router.put("/profile", response_model=UserResponse)
def update_profile(
data: UserUpdate,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""更新个人资料"""
# 修改用户名
if data.username and data.username != current_user.username:
if db.query(User).filter(User.username == data.username, User.id != current_user.id).first():
raise HTTPException(status_code=400, detail="用户名已存在")
current_user.username = data.username
# 修改邮箱
if data.email and data.email != current_user.email:
if db.query(User).filter(User.email == data.email, User.id != current_user.id).first():
raise HTTPException(status_code=400, detail="邮箱已被使用")
current_user.email = data.email
# 修改头像
if data.avatar is not None:
current_user.avatar = data.avatar
# 修改密码
if data.new_password:
if not data.old_password:
raise HTTPException(status_code=400, detail="请输入当前密码")
if not pwd_context.verify(data.old_password, current_user.password_hash):
raise HTTPException(status_code=400, detail="当前密码错误")
current_user.password_hash = pwd_context.hash(data.new_password)
db.commit()
db.refresh(current_user)
return UserResponse.model_validate(current_user)

View File

@@ -0,0 +1,118 @@
"""网站收藏路由"""
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from typing import List
from database import get_db
from models.bookmark import BookmarkSite
from models.user import User
from schemas.bookmark import BookmarkCreate, BookmarkUpdate, BookmarkResponse, ReorderRequest
from routers.auth import get_current_user
router = APIRouter()
@router.get("", response_model=List[BookmarkResponse])
def get_bookmarks(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""获取当前用户的收藏网站列表"""
bookmarks = (
db.query(BookmarkSite)
.filter(BookmarkSite.user_id == current_user.id)
.order_by(BookmarkSite.sort_order, BookmarkSite.created_at)
.all()
)
return [BookmarkResponse.model_validate(b) for b in bookmarks]
@router.post("", response_model=BookmarkResponse)
def create_bookmark(
data: BookmarkCreate,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""添加收藏网站"""
# 获取当前最大排序值
max_order = (
db.query(BookmarkSite.sort_order)
.filter(BookmarkSite.user_id == current_user.id)
.order_by(BookmarkSite.sort_order.desc())
.first()
)
next_order = (max_order[0] + 1) if max_order else 0
bookmark = BookmarkSite(
user_id=current_user.id,
name=data.name,
url=data.url,
icon=data.icon,
sort_order=next_order,
)
db.add(bookmark)
db.commit()
db.refresh(bookmark)
return BookmarkResponse.model_validate(bookmark)
@router.put("/reorder")
def reorder_bookmarks(
data: ReorderRequest,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""批量更新排序"""
for item in data.items:
db.query(BookmarkSite).filter(
BookmarkSite.id == item.id,
BookmarkSite.user_id == current_user.id,
).update({"sort_order": item.sort_order})
db.commit()
return {"message": "排序已更新"}
@router.put("/{bookmark_id}", response_model=BookmarkResponse)
def update_bookmark(
bookmark_id: int,
data: BookmarkUpdate,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""编辑收藏网站"""
bookmark = db.query(BookmarkSite).filter(
BookmarkSite.id == bookmark_id,
BookmarkSite.user_id == current_user.id,
).first()
if not bookmark:
raise HTTPException(status_code=404, detail="收藏不存在")
if data.name is not None:
bookmark.name = data.name
if data.url is not None:
bookmark.url = data.url
if data.icon is not None:
bookmark.icon = data.icon
if data.sort_order is not None:
bookmark.sort_order = data.sort_order
db.commit()
db.refresh(bookmark)
return BookmarkResponse.model_validate(bookmark)
@router.delete("/{bookmark_id}")
def delete_bookmark(
bookmark_id: int,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""删除收藏网站"""
bookmark = db.query(BookmarkSite).filter(
BookmarkSite.id == bookmark_id,
BookmarkSite.user_id == current_user.id,
).first()
if not bookmark:
raise HTTPException(status_code=404, detail="收藏不存在")
db.delete(bookmark)
db.commit()
return {"message": "删除成功"}

View File

@@ -0,0 +1,633 @@
"""团队知识库路由"""
import json
import hashlib
from typing import Optional, List
from datetime import datetime, timedelta
from fastapi import APIRouter, Depends, HTTPException, Header, Query
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from sqlalchemy import func as sa_func, or_
from pydantic import BaseModel
from database import get_db
from config import SECRET_KEY
from models.user import User
from models.post import Post
from models.system_config import SystemConfig
from models.knowledge_base import KbCategory, KbItem, KbAccessLog
from models.attachment import Attachment
from routers.auth import get_current_user, get_admin_user
from services.ai_service import ai_service
router = APIRouter()
# ========== 密码机制(复用 API Hub 方案) ==========
def _get_kb_password(db: Session) -> str:
cfg = db.query(SystemConfig).filter(SystemConfig.key == "kb_password").first()
return cfg.value if cfg else ""
def _password_version(db: Session) -> str:
"""返回密码哈希前8位作为版本标识密码变更后旧token自动失效"""
pwd = _get_kb_password(db)
return pwd[:8] if pwd else "none"
def _hash_password(pwd: str) -> str:
return hashlib.sha256(pwd.encode()).hexdigest()
def _create_kb_token(user_id: int, pwd_ver: str = "none") -> str:
from jose import jwt
exp = datetime.utcnow() + timedelta(hours=2)
return jwt.encode({"sub": str(user_id), "kb": True, "pv": pwd_ver, "exp": exp}, SECRET_KEY, algorithm="HS256")
def verify_kb_access(
x_kb_token: Optional[str] = Header(None),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""验证用户登录 + 知识库访问令牌"""
if not x_kb_token:
raise HTTPException(status_code=403, detail="需要知识库访问权限,请先验证密码")
from jose import jwt, JWTError
try:
payload = jwt.decode(x_kb_token, SECRET_KEY, algorithms=["HS256"])
if not payload.get("kb"):
raise HTTPException(status_code=403, detail="无效的知识库令牌")
# 检查密码版本是否匹配
token_pv = payload.get("pv", "")
current_pv = _password_version(db)
if token_pv != current_pv:
raise HTTPException(status_code=403, detail="密码已变更,请重新验证")
except JWTError:
raise HTTPException(status_code=403, detail="知识库令牌已过期,请重新验证密码")
return current_user
# ========== Schemas ==========
class KbCategoryCreate(BaseModel):
name: str
icon: str = ""
class KbCategoryUpdate(BaseModel):
name: Optional[str] = None
icon: Optional[str] = None
sort_order: Optional[int] = None
is_active: Optional[bool] = None
class KbItemAdd(BaseModel):
post_ids: List[int]
category_id: Optional[int] = None
class KbItemUpdate(BaseModel):
category_id: Optional[int] = None
title: Optional[str] = None
summary: Optional[str] = None
sort_order: Optional[int] = None
is_active: Optional[bool] = None
class KbAiChatRequest(BaseModel):
question: str
class PasswordBody(BaseModel):
password: str
# ========== 密码认证接口 ==========
@router.post("/auth")
def kb_auth(
body: PasswordBody,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""知识库密码验证"""
stored = _get_kb_password(db)
pv = _password_version(db)
if not stored:
# 未设置密码,直接放行
token = _create_kb_token(current_user.id, pv)
return {"token": token}
if _hash_password(body.password) != stored:
raise HTTPException(status_code=401, detail="密码错误")
token = _create_kb_token(current_user.id, pv)
return {"token": token}
@router.get("/check-password")
def kb_check_password(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""检查知识库是否需要密码"""
stored = _get_kb_password(db)
return {"has_password": bool(stored)}
# ========== 公开接口(需登录 + kb_token ==========
@router.get("/categories")
def get_categories(
user: User = Depends(verify_kb_access),
db: Session = Depends(get_db),
):
"""获取知识库分类列表"""
cats = db.query(KbCategory).filter(KbCategory.is_active == True).order_by(KbCategory.sort_order, KbCategory.id).all()
result = []
for c in cats:
count = db.query(sa_func.count(KbItem.id)).filter(KbItem.category_id == c.id, KbItem.is_active == True).scalar() or 0
result.append({"id": c.id, "name": c.name, "icon": c.icon, "count": count})
return result
@router.get("/items")
def get_items(
page: int = Query(1, ge=1),
size: int = Query(20, ge=1, le=100),
category_id: Optional[int] = None,
keyword: Optional[str] = None,
user: User = Depends(verify_kb_access),
db: Session = Depends(get_db),
):
"""获取知识库条目列表"""
query = db.query(KbItem).filter(KbItem.is_active == True)
if category_id:
query = query.filter(KbItem.category_id == category_id)
if keyword:
kw = f"%{keyword}%"
query = query.filter(or_(KbItem.title.like(kw), KbItem.summary.like(kw)))
total = query.count()
items = query.order_by(KbItem.sort_order, KbItem.created_at.desc()).offset((page - 1) * size).limit(size).all()
result = []
for item in items:
post = db.query(Post).filter(Post.id == item.post_id).first()
cat = db.query(KbCategory).filter(KbCategory.id == item.category_id).first() if item.category_id else None
author = db.query(User).filter(User.id == item.added_by).first() if item.added_by else None
result.append({
"id": item.id,
"title": item.title,
"summary": item.summary or (post.content[:150] if post else ""),
"category_id": item.category_id,
"category_name": cat.name if cat else "",
"post_id": item.post_id,
"post_author": post.user_id if post else None,
"added_by_name": author.username if author else "",
"created_at": item.created_at.isoformat() if item.created_at else None,
})
# 记录访问日志
if keyword:
db.add(KbAccessLog(user_id=user.id, action="search", query=keyword))
db.commit()
return {"items": result, "total": total, "page": page, "size": size}
@router.get("/items/{item_id}")
def get_item_detail(
item_id: int,
user: User = Depends(verify_kb_access),
db: Session = Depends(get_db),
):
"""获取知识库条目详情(含帖子完整内容)"""
item = db.query(KbItem).filter(KbItem.id == item_id, KbItem.is_active == True).first()
if not item:
raise HTTPException(status_code=404, detail="条目不存在")
post = db.query(Post).filter(Post.id == item.post_id).first()
cat = db.query(KbCategory).filter(KbCategory.id == item.category_id).first() if item.category_id else None
post_author = db.query(User).filter(User.id == post.user_id).first() if post else None
# 记录访问
db.add(KbAccessLog(user_id=user.id, action="view", query=str(item_id)))
db.commit()
return {
"id": item.id,
"title": item.title,
"summary": item.summary,
"category_id": item.category_id,
"category_name": cat.name if cat else "",
"post_id": item.post_id,
"post_title": post.title if post else "",
"post_content": post.content if post else "",
"post_author": {"id": post_author.id, "username": post_author.username} if post_author else None,
"post_tags": post.tags if post else "",
"post_category": post.category if post else "",
"created_at": item.created_at.isoformat() if item.created_at else None,
"attachments": [
{"id": a.id, "filename": a.filename, "url": a.url, "file_size": a.file_size, "file_type": a.file_type}
for a in db.query(Attachment).filter(Attachment.post_id == item.post_id).order_by(Attachment.created_at.asc()).all()
] if post else [],
}
@router.get("/stats")
def get_kb_stats(
user: User = Depends(verify_kb_access),
db: Session = Depends(get_db),
):
"""获取知识库统计"""
total_items = db.query(sa_func.count(KbItem.id)).filter(KbItem.is_active == True).scalar() or 0
total_categories = db.query(sa_func.count(KbCategory.id)).filter(KbCategory.is_active == True).scalar() or 0
total_views = db.query(sa_func.count(KbAccessLog.id)).filter(KbAccessLog.action == "view").scalar() or 0
total_searches = db.query(sa_func.count(KbAccessLog.id)).filter(KbAccessLog.action == "search").scalar() or 0
total_ai_chats = db.query(sa_func.count(KbAccessLog.id)).filter(KbAccessLog.action == "ai_chat").scalar() or 0
return {
"total_items": total_items,
"total_categories": total_categories,
"total_views": total_views,
"total_searches": total_searches,
"total_ai_chats": total_ai_chats,
}
# ========== AI 智能问答 ==========
KB_AI_SYSTEM_PROMPT = """你是一个团队知识库的AI助手。你的任务是根据知识库中的内容回答用户的问题。
规则:
1. 只基于提供的知识库内容回答问题,不编造信息
2. 如果知识库中没有相关内容,诚实告知用户
3. 回答时引用来源文章的标题
4. 使用 Markdown 格式组织回答
5. 回答要简洁、专业、有条理"""
@router.post("/ai-chat")
async def kb_ai_chat(
request: KbAiChatRequest,
user: User = Depends(verify_kb_access),
db: Session = Depends(get_db),
):
"""AI 智能问答SSE 流式)"""
question = request.question.strip()
if not question:
raise HTTPException(status_code=400, detail="请输入问题")
# 从知识库中检索相关内容作为上下文
kw = f"%{question[:50]}%"
# 简单关键词匹配(后续可升级为向量搜索)
related_items = (
db.query(KbItem)
.filter(KbItem.is_active == True)
.filter(or_(KbItem.title.like(kw), KbItem.summary.like(kw)))
.limit(5)
.all()
)
# 如果关键词搜索结果不足,补充最新的条目
if len(related_items) < 3:
existing_ids = [item.id for item in related_items]
extra = (
db.query(KbItem)
.filter(KbItem.is_active == True, ~KbItem.id.in_(existing_ids) if existing_ids else True)
.order_by(KbItem.created_at.desc())
.limit(10 - len(related_items))
.all()
)
related_items.extend(extra)
# 构建知识库上下文
context_parts = []
for item in related_items:
post = db.query(Post).filter(Post.id == item.post_id).first()
if post:
content = post.content[:2000] # 限制每篇长度
context_parts.append(f"### {item.title}\n{content}")
kb_context = "\n\n---\n\n".join(context_parts) if context_parts else "知识库暂无相关内容。"
messages = [
{"role": "user", "content": f"以下是知识库中的相关内容:\n\n{kb_context}\n\n---\n\n用户问题:{question}"}
]
# 记录日志
db.add(KbAccessLog(user_id=user.id, action="ai_chat", query=question))
db.commit()
# 流式调用AI
async def generate():
full_response = ""
try:
result = await ai_service.chat(
task_type="reasoning",
messages=messages,
system_prompt=KB_AI_SYSTEM_PROMPT,
stream=True,
)
if isinstance(result, str):
full_response = result
yield f"data: {json.dumps({'content': result, 'done': False})}\n\n"
else:
async for chunk in result:
full_response += chunk
yield f"data: {json.dumps({'content': chunk, 'done': False})}\n\n"
except Exception as e:
error_msg = f"AI调用出错: {str(e)}"
yield f"data: {json.dumps({'content': error_msg, 'done': False})}\n\n"
yield f"data: {json.dumps({'content': '', 'done': True})}\n\n"
return StreamingResponse(generate(), media_type="text/event-stream")
# ========== 管理员接口 ==========
# --- 密码管理 ---
@router.put("/admin/password")
def set_kb_password(
body: PasswordBody,
admin: User = Depends(get_admin_user),
db: Session = Depends(get_db),
):
"""设置/修改知识库密码"""
hashed = _hash_password(body.password) if body.password else ""
row = db.query(SystemConfig).filter(SystemConfig.key == "kb_password").first()
if row:
row.value = hashed
else:
db.add(SystemConfig(key="kb_password", value=hashed, description="知识库访问密码"))
db.commit()
return {"message": "密码已更新"}
@router.get("/admin/password-status")
def get_kb_password_status(
admin: User = Depends(get_admin_user),
db: Session = Depends(get_db),
):
"""获取知识库密码状态"""
stored = _get_kb_password(db)
return {"has_password": bool(stored)}
# --- 分类管理 ---
@router.get("/admin/categories")
def admin_list_categories(
admin: User = Depends(get_admin_user),
db: Session = Depends(get_db),
):
"""获取所有知识库分类(含禁用)"""
cats = db.query(KbCategory).order_by(KbCategory.sort_order, KbCategory.id).all()
return [
{
"id": c.id, "name": c.name, "icon": c.icon,
"sort_order": c.sort_order, "is_active": c.is_active,
"item_count": db.query(sa_func.count(KbItem.id)).filter(KbItem.category_id == c.id).scalar() or 0,
"created_at": c.created_at.isoformat() if c.created_at else None,
}
for c in cats
]
@router.post("/admin/categories")
def admin_create_category(
data: KbCategoryCreate,
admin: User = Depends(get_admin_user),
db: Session = Depends(get_db),
):
"""创建知识库分类"""
exists = db.query(KbCategory).filter(KbCategory.name == data.name).first()
if exists:
raise HTTPException(status_code=400, detail="分类名称已存在")
cat = KbCategory(name=data.name, icon=data.icon)
db.add(cat)
db.commit()
db.refresh(cat)
return {"id": cat.id, "name": cat.name, "message": "创建成功"}
@router.put("/admin/categories/{cat_id}")
def admin_update_category(
cat_id: int,
data: KbCategoryUpdate,
admin: User = Depends(get_admin_user),
db: Session = Depends(get_db),
):
"""更新知识库分类"""
cat = db.query(KbCategory).filter(KbCategory.id == cat_id).first()
if not cat:
raise HTTPException(status_code=404, detail="分类不存在")
updates = data.dict(exclude_none=True)
for key, value in updates.items():
setattr(cat, key, value)
db.commit()
return {"message": "更新成功"}
@router.delete("/admin/categories/{cat_id}")
def admin_delete_category(
cat_id: int,
admin: User = Depends(get_admin_user),
db: Session = Depends(get_db),
):
"""删除知识库分类"""
cat = db.query(KbCategory).filter(KbCategory.id == cat_id).first()
if not cat:
raise HTTPException(status_code=404, detail="分类不存在")
# 将该分类下的条目设为未分类
db.query(KbItem).filter(KbItem.category_id == cat_id).update({"category_id": None})
db.delete(cat)
db.commit()
return {"message": "删除成功"}
# --- 条目管理 ---
@router.get("/admin/items")
def admin_list_items(
page: int = Query(1, ge=1),
size: int = Query(20, ge=1, le=100),
category_id: Optional[int] = None,
keyword: Optional[str] = None,
admin: User = Depends(get_admin_user),
db: Session = Depends(get_db),
):
"""获取所有知识库条目"""
query = db.query(KbItem)
if category_id:
query = query.filter(KbItem.category_id == category_id)
if keyword:
kw = f"%{keyword}%"
query = query.filter(or_(KbItem.title.like(kw), KbItem.summary.like(kw)))
total = query.count()
items = query.order_by(KbItem.created_at.desc()).offset((page - 1) * size).limit(size).all()
result = []
for item in items:
post = db.query(Post).filter(Post.id == item.post_id).first()
cat = db.query(KbCategory).filter(KbCategory.id == item.category_id).first() if item.category_id else None
author = db.query(User).filter(User.id == item.added_by).first() if item.added_by else None
post_author = db.query(User).filter(User.id == post.user_id).first() if post else None
result.append({
"id": item.id,
"title": item.title,
"summary": item.summary or "",
"category_id": item.category_id,
"category_name": cat.name if cat else "未分类",
"post_id": item.post_id,
"post_title": post.title if post else "(已删除)",
"post_author_name": post_author.username if post_author else "",
"added_by_name": author.username if author else "",
"is_active": item.is_active,
"sort_order": item.sort_order,
"created_at": item.created_at.isoformat() if item.created_at else None,
})
return {"items": result, "total": total, "page": page, "size": size}
@router.post("/admin/items")
def admin_add_items(
data: KbItemAdd,
admin: User = Depends(get_admin_user),
db: Session = Depends(get_db),
):
"""从帖子批量添加到知识库"""
added = 0
skipped = 0
for post_id in data.post_ids:
# 检查是否已存在
exists = db.query(KbItem).filter(KbItem.post_id == post_id).first()
if exists:
skipped += 1
continue
post = db.query(Post).filter(Post.id == post_id).first()
if not post:
skipped += 1
continue
item = KbItem(
post_id=post_id,
category_id=data.category_id,
title=post.title,
summary=post.content[:200] if post.content else "",
added_by=admin.id,
)
db.add(item)
added += 1
db.commit()
return {"message": f"已添加 {added} 条,跳过 {skipped} 条(已存在或不存在)", "added": added, "skipped": skipped}
@router.put("/admin/items/{item_id}")
def admin_update_item(
item_id: int,
data: KbItemUpdate,
admin: User = Depends(get_admin_user),
db: Session = Depends(get_db),
):
"""更新知识库条目"""
item = db.query(KbItem).filter(KbItem.id == item_id).first()
if not item:
raise HTTPException(status_code=404, detail="条目不存在")
updates = data.dict(exclude_none=True)
for key, value in updates.items():
setattr(item, key, value)
db.commit()
return {"message": "更新成功"}
@router.delete("/admin/items/{item_id}")
def admin_delete_item(
item_id: int,
admin: User = Depends(get_admin_user),
db: Session = Depends(get_db),
):
"""删除知识库条目"""
item = db.query(KbItem).filter(KbItem.id == item_id).first()
if not item:
raise HTTPException(status_code=404, detail="条目不存在")
db.delete(item)
db.commit()
return {"message": "删除成功"}
@router.get("/admin/posts-for-pick")
def admin_posts_for_pick(
page: int = Query(1, ge=1),
size: int = Query(20, ge=1, le=50),
keyword: Optional[str] = None,
category: Optional[str] = None,
admin: User = Depends(get_admin_user),
db: Session = Depends(get_db),
):
"""获取可选帖子列表(排除已加入知识库的)"""
existing_post_ids = [r[0] for r in db.query(KbItem.post_id).all()]
query = db.query(Post).filter(Post.is_public == True)
if existing_post_ids:
query = query.filter(~Post.id.in_(existing_post_ids))
if keyword:
kw = f"%{keyword}%"
query = query.filter(or_(Post.title.like(kw), Post.content.like(kw)))
if category:
query = query.filter(Post.category == category)
total = query.count()
posts = query.order_by(Post.created_at.desc()).offset((page - 1) * size).limit(size).all()
result = []
for p in posts:
author = db.query(User).filter(User.id == p.user_id).first()
result.append({
"id": p.id,
"title": p.title,
"category": p.category,
"content_preview": p.content[:100] if p.content else "",
"author_name": author.username if author else "",
"like_count": p.like_count,
"view_count": p.view_count,
"created_at": p.created_at.isoformat() if p.created_at else None,
})
return {"items": result, "total": total, "page": page, "size": size}
# --- 管理员统计 ---
@router.get("/admin/stats")
def admin_kb_stats(
admin: User = Depends(get_admin_user),
db: Session = Depends(get_db),
):
"""管理员统计数据"""
total_items = db.query(sa_func.count(KbItem.id)).scalar() or 0
active_items = db.query(sa_func.count(KbItem.id)).filter(KbItem.is_active == True).scalar() or 0
total_categories = db.query(sa_func.count(KbCategory.id)).scalar() or 0
total_views = db.query(sa_func.count(KbAccessLog.id)).filter(KbAccessLog.action == "view").scalar() or 0
total_searches = db.query(sa_func.count(KbAccessLog.id)).filter(KbAccessLog.action == "search").scalar() or 0
total_ai_chats = db.query(sa_func.count(KbAccessLog.id)).filter(KbAccessLog.action == "ai_chat").scalar() or 0
# 最近7天趋势
from datetime import date
today = date.today()
daily_stats = []
for i in range(6, -1, -1):
d = today - timedelta(days=i)
day_start = datetime.combine(d, datetime.min.time())
day_end = datetime.combine(d, datetime.max.time())
views = db.query(sa_func.count(KbAccessLog.id)).filter(
KbAccessLog.action == "view", KbAccessLog.created_at.between(day_start, day_end)
).scalar() or 0
searches = db.query(sa_func.count(KbAccessLog.id)).filter(
KbAccessLog.action == "search", KbAccessLog.created_at.between(day_start, day_end)
).scalar() or 0
ai_chats = db.query(sa_func.count(KbAccessLog.id)).filter(
KbAccessLog.action == "ai_chat", KbAccessLog.created_at.between(day_start, day_end)
).scalar() or 0
daily_stats.append({"date": d.isoformat(), "views": views, "searches": searches, "ai_chats": ai_chats})
return {
"total_items": total_items,
"active_items": active_items,
"total_categories": total_categories,
"total_views": total_views,
"total_searches": total_searches,
"total_ai_chats": total_ai_chats,
"daily_stats": daily_stats,
}

360
backend/routers/nav.py Normal file
View File

@@ -0,0 +1,360 @@
"""导航站路由"""
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session
from sqlalchemy import func as sa_func
from pydantic import BaseModel
from typing import Optional
from database import get_db
from models.user import User
from models.nav_category import NavCategory
from models.nav_link import NavLink
from routers.auth import get_current_user, get_admin_user
router = APIRouter()
# ========== Schemas ==========
class NavCategoryCreate(BaseModel):
name: str
icon: str = ""
class NavCategoryUpdate(BaseModel):
name: Optional[str] = None
icon: Optional[str] = None
sort_order: Optional[int] = None
is_active: Optional[bool] = None
class NavLinkCreate(BaseModel):
category_id: int
name: str
url: str
icon: str = ""
description: str = ""
class NavLinkUpdate(BaseModel):
category_id: Optional[int] = None
name: Optional[str] = None
url: Optional[str] = None
icon: Optional[str] = None
description: Optional[str] = None
sort_order: Optional[int] = None
is_active: Optional[bool] = None
class NavLinkSubmit(BaseModel):
"""用户提交导航链接"""
category_id: int
name: str
url: str
icon: str = ""
description: str = ""
class NavLinkReview(BaseModel):
"""审核操作"""
action: str # approve / reject
reject_reason: str = ""
# ========== 管理员接口 ==========
@router.get("/admin/categories")
def admin_list_categories(
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user),
):
"""获取所有导航分类(含禁用)"""
cats = db.query(NavCategory).order_by(NavCategory.sort_order, NavCategory.id).all()
return [
{
"id": c.id, "name": c.name, "icon": c.icon,
"sort_order": c.sort_order, "is_active": c.is_active,
"link_count": db.query(sa_func.count(NavLink.id)).filter(NavLink.category_id == c.id).scalar() or 0,
}
for c in cats
]
@router.post("/admin/categories")
def admin_create_category(
data: NavCategoryCreate,
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user),
):
"""新增导航分类"""
existing = db.query(NavCategory).filter(NavCategory.name == data.name).first()
if existing:
raise HTTPException(status_code=400, detail="分类名称已存在")
max_order = db.query(sa_func.max(NavCategory.sort_order)).scalar() or 0
cat = NavCategory(name=data.name, icon=data.icon, sort_order=max_order + 1)
db.add(cat)
db.commit()
db.refresh(cat)
return {"id": cat.id, "name": cat.name, "icon": cat.icon, "sort_order": cat.sort_order, "is_active": cat.is_active}
@router.put("/admin/categories/{cat_id}")
def admin_update_category(
cat_id: int,
data: NavCategoryUpdate,
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user),
):
"""编辑导航分类"""
cat = db.query(NavCategory).filter(NavCategory.id == cat_id).first()
if not cat:
raise HTTPException(status_code=404, detail="分类不存在")
if data.name is not None:
dup = db.query(NavCategory).filter(NavCategory.name == data.name, NavCategory.id != cat_id).first()
if dup:
raise HTTPException(status_code=400, detail="分类名称已存在")
cat.name = data.name
if data.icon is not None:
cat.icon = data.icon
if data.sort_order is not None:
cat.sort_order = data.sort_order
if data.is_active is not None:
cat.is_active = data.is_active
db.commit()
db.refresh(cat)
return {"id": cat.id, "name": cat.name, "icon": cat.icon, "sort_order": cat.sort_order, "is_active": cat.is_active}
@router.delete("/admin/categories/{cat_id}")
def admin_delete_category(
cat_id: int,
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user),
):
"""删除导航分类(级联删除链接)"""
cat = db.query(NavCategory).filter(NavCategory.id == cat_id).first()
if not cat:
raise HTTPException(status_code=404, detail="分类不存在")
db.query(NavLink).filter(NavLink.category_id == cat_id).delete()
db.delete(cat)
db.commit()
return {"message": "删除成功"}
@router.get("/admin/links")
def admin_list_links(
category_id: Optional[int] = None,
status: Optional[str] = None,
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user),
):
"""获取导航链接列表"""
query = db.query(NavLink)
if category_id is not None:
query = query.filter(NavLink.category_id == category_id)
if status is not None:
query = query.filter(NavLink.status == status)
links = query.order_by(NavLink.sort_order, NavLink.id).all()
return [
{
"id": l.id, "category_id": l.category_id, "name": l.name,
"url": l.url, "icon": l.icon, "description": l.description,
"sort_order": l.sort_order, "is_active": l.is_active,
"status": l.status, "submitted_by": l.submitted_by,
"reject_reason": l.reject_reason or "",
}
for l in links
]
@router.post("/admin/links")
def admin_create_link(
data: NavLinkCreate,
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user),
):
"""新增导航链接"""
cat = db.query(NavCategory).filter(NavCategory.id == data.category_id).first()
if not cat:
raise HTTPException(status_code=400, detail="分类不存在")
max_order = db.query(sa_func.max(NavLink.sort_order)).filter(NavLink.category_id == data.category_id).scalar() or 0
link = NavLink(
category_id=data.category_id, name=data.name, url=data.url,
icon=data.icon, description=data.description, sort_order=max_order + 1,
)
db.add(link)
db.commit()
db.refresh(link)
return {
"id": link.id, "category_id": link.category_id, "name": link.name,
"url": link.url, "icon": link.icon, "description": link.description,
"sort_order": link.sort_order, "is_active": link.is_active,
}
@router.put("/admin/links/{link_id}")
def admin_update_link(
link_id: int,
data: NavLinkUpdate,
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user),
):
"""编辑导航链接"""
link = db.query(NavLink).filter(NavLink.id == link_id).first()
if not link:
raise HTTPException(status_code=404, detail="链接不存在")
if data.category_id is not None:
link.category_id = data.category_id
if data.name is not None:
link.name = data.name
if data.url is not None:
link.url = data.url
if data.icon is not None:
link.icon = data.icon
if data.description is not None:
link.description = data.description
if data.sort_order is not None:
link.sort_order = data.sort_order
if data.is_active is not None:
link.is_active = data.is_active
db.commit()
db.refresh(link)
return {
"id": link.id, "category_id": link.category_id, "name": link.name,
"url": link.url, "icon": link.icon, "description": link.description,
"sort_order": link.sort_order, "is_active": link.is_active,
}
@router.delete("/admin/links/{link_id}")
def admin_delete_link(
link_id: int,
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user),
):
"""删除导航链接"""
link = db.query(NavLink).filter(NavLink.id == link_id).first()
if not link:
raise HTTPException(status_code=404, detail="链接不存在")
db.delete(link)
db.commit()
return {"message": "删除成功"}
@router.get("/admin/pending-count")
def admin_pending_count(
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user),
):
"""获取待审核数量"""
count = db.query(sa_func.count(NavLink.id)).filter(NavLink.status == "pending").scalar() or 0
return {"count": count}
@router.put("/admin/links/{link_id}/review")
def admin_review_link(
link_id: int,
data: NavLinkReview,
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user),
):
"""审核导航链接"""
link = db.query(NavLink).filter(NavLink.id == link_id).first()
if not link:
raise HTTPException(status_code=404, detail="链接不存在")
if data.action == "approve":
link.status = "approved"
link.is_active = True
link.reject_reason = ""
elif data.action == "reject":
link.status = "rejected"
link.is_active = False
link.reject_reason = data.reject_reason
else:
raise HTTPException(status_code=400, detail="无效操作,请使用 approve 或 reject")
db.commit()
db.refresh(link)
return {
"id": link.id, "status": link.status,
"is_active": link.is_active, "reject_reason": link.reject_reason or "",
}
# ========== 用户提交接口 ==========
@router.post("/submit")
def user_submit_link(
data: NavLinkSubmit,
db: Session = Depends(get_db),
user: User = Depends(get_current_user),
):
"""用户提交导航网站(需管理员审核)"""
cat = db.query(NavCategory).filter(NavCategory.id == data.category_id).first()
if not cat:
raise HTTPException(status_code=400, detail="分类不存在")
existing = db.query(NavLink).filter(NavLink.url == data.url).first()
if existing:
raise HTTPException(status_code=400, detail="该网站已被提交过")
link = NavLink(
category_id=data.category_id, name=data.name, url=data.url,
icon=data.icon, description=data.description,
status="pending", submitted_by=user.id, is_active=False,
)
db.add(link)
db.commit()
db.refresh(link)
return {
"id": link.id, "name": link.name, "url": link.url,
"status": link.status, "message": "提交成功,等待管理员审核",
}
@router.get("/my-submissions")
def user_my_submissions(
db: Session = Depends(get_db),
user: User = Depends(get_current_user),
):
"""用户查看自己提交的记录"""
links = db.query(NavLink).filter(NavLink.submitted_by == user.id).order_by(NavLink.id.desc()).all()
return [
{
"id": l.id, "name": l.name, "url": l.url, "icon": l.icon,
"description": l.description, "status": l.status,
"reject_reason": l.reject_reason or "",
"created_at": l.created_at.isoformat() if l.created_at else "",
}
for l in links
]
# ========== 公开接口 ==========
@router.get("/public")
def get_public_nav(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""获取所有启用分类及其下启用的链接"""
cats = db.query(NavCategory).filter(NavCategory.is_active == True).order_by(NavCategory.sort_order, NavCategory.id).all()
result = []
for c in cats:
links = (
db.query(NavLink)
.filter(NavLink.category_id == c.id, NavLink.is_active == True, NavLink.status == "approved")
.order_by(NavLink.sort_order, NavLink.id)
.all()
)
if links:
result.append({
"id": c.id, "name": c.name, "icon": c.icon,
"links": [
{"id": l.id, "name": l.name, "url": l.url, "icon": l.icon, "description": l.description}
for l in links
],
})
return result
@router.get("/public/categories")
def get_public_categories(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""获取所有启用分类(供用户提交时选择)"""
cats = db.query(NavCategory).filter(NavCategory.is_active == True).order_by(NavCategory.sort_order, NavCategory.id).all()
return [{"id": c.id, "name": c.name, "icon": c.icon} for c in cats]

View File

@@ -0,0 +1,89 @@
"""消息通知路由"""
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from sqlalchemy import func
from database import get_db
from models.user import User
from models.notification import Notification
from routers.auth import get_current_user
router = APIRouter()
@router.get("")
def get_notifications(
page: int = 1,
page_size: int = 30,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""获取通知列表"""
notifs = (
db.query(Notification)
.filter(Notification.user_id == current_user.id)
.order_by(Notification.created_at.desc())
.offset((page - 1) * page_size).limit(page_size).all()
)
# 获取触发用户信息
from_ids = list(set(n.from_user_id for n in notifs if n.from_user_id))
users = db.query(User).filter(User.id.in_(from_ids)).all() if from_ids else []
user_map = {u.id: u for u in users}
return [
{
"id": n.id,
"type": n.type,
"content": n.content,
"related_id": n.related_id,
"is_read": n.is_read,
"created_at": n.created_at,
"from_user": {
"id": n.from_user_id,
"username": user_map[n.from_user_id].username,
"avatar": user_map[n.from_user_id].avatar,
} if n.from_user_id and n.from_user_id in user_map else None,
}
for n in notifs
]
@router.get("/unread-count")
def get_unread_count(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""获取未读通知数量"""
count = (
db.query(func.count(Notification.id))
.filter(Notification.user_id == current_user.id, Notification.is_read == False)
.scalar()
)
return {"count": count}
@router.put("/read-all")
def read_all(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""全部标为已读"""
db.query(Notification).filter(
Notification.user_id == current_user.id, Notification.is_read == False
).update({"is_read": True})
db.commit()
return {"message": "已全部标为已读"}
@router.put("/{notif_id}/read")
def read_one(
notif_id: int,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""单条标为已读"""
db.query(Notification).filter(
Notification.id == notif_id, Notification.user_id == current_user.id
).update({"is_read": True})
db.commit()
return {"message": "已标为已读"}

440
backend/routers/posts.py Normal file
View File

@@ -0,0 +1,440 @@
"""经验知识库路由"""
import json
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session
from sqlalchemy import or_
from database import get_db
from models.user import User
from models.post import Post
from models.comment import Comment
from models.like import Like, Collect
from models.follow import Follow
from models.notification import Notification
from models.attachment import Attachment
from schemas.post import (
PostCreate, PostUpdate, PostResponse, PostListResponse,
CommentCreate, CommentResponse,
)
from routers.auth import get_current_user
router = APIRouter()
import re
def _extract_cover_image(content: str) -> str:
"""从Markdown内容中提取第一张图片作为封面"""
if not content:
return ""
match = re.search(r'!\[.*?\]\((.*?)\)', content)
if match:
return match.group(1)
img_match = re.search(r'<img[^>]+src=["\']([^"\']+)["\']', content)
if img_match:
return img_match.group(1)
return ""
def _enrich_post_with_author(post: Post, db: Session) -> dict:
"""为帖子附加作者信息(用于信息流)"""
author = db.query(User).filter(User.id == post.user_id).first()
return {
"id": post.id, "title": post.title, "content": post.content[:200],
"category": post.category, "tags": post.tags,
"cover_image": _extract_cover_image(post.content),
"view_count": post.view_count, "like_count": post.like_count,
"comment_count": post.comment_count, "collect_count": post.collect_count,
"created_at": post.created_at, "updated_at": post.updated_at,
"author": {
"id": post.user_id,
"username": author.username if author else "未知",
"avatar": author.avatar if author else "",
},
}
@router.get("/feed")
def get_feed(
page: int = 1, page_size: int = 20,
category: Optional[str] = None,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""关注的人的帖子流"""
following_ids = [f.following_id for f in db.query(Follow.following_id).filter(Follow.follower_id == current_user.id).all()]
if not following_ids:
return {"items": [], "total": 0, "page": page, "page_size": page_size}
query = db.query(Post).filter(Post.user_id.in_(following_ids), Post.is_public == True, Post.is_draft == False)
if category:
query = query.filter(Post.category == category)
total = query.count()
posts = (
query.order_by(Post.created_at.desc())
.offset((page - 1) * page_size).limit(page_size).all()
)
return {"items": [_enrich_post_with_author(p, db) for p in posts], "total": total, "page": page, "page_size": page_size}
@router.get("/hot")
def get_hot_posts(
page: int = 1, page_size: int = 20,
category: Optional[str] = None,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""热门帖子(按热度排序)"""
query = db.query(Post).filter(Post.is_public == True, Post.is_draft == False)
if category:
query = query.filter(Post.category == category)
total = query.count()
posts = (
query.order_by((Post.like_count * 3 + Post.comment_count * 2 + Post.view_count).desc())
.offset((page - 1) * page_size).limit(page_size).all()
)
return {"items": [_enrich_post_with_author(p, db) for p in posts], "total": total, "page": page, "page_size": page_size}
@router.get("/latest")
def get_latest_posts(
page: int = 1, page_size: int = 20,
category: Optional[str] = None,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""最新帖子"""
query = db.query(Post).filter(Post.is_public == True, Post.is_draft == False)
if category:
query = query.filter(Post.category == category)
total = query.count()
posts = (
query.order_by(Post.created_at.desc())
.offset((page - 1) * page_size).limit(page_size).all()
)
return {"items": [_enrich_post_with_author(p, db) for p in posts], "total": total, "page": page, "page_size": page_size}
@router.get("/drafts")
def get_drafts(
page: int = 1, page_size: int = 20,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""获取当前用户的草稿列表"""
query = db.query(Post).filter(Post.user_id == current_user.id, Post.is_draft == True)
total = query.count()
posts = query.order_by(Post.updated_at.desc()).offset((page - 1) * page_size).limit(page_size).all()
return {
"items": [_enrich_post(p, db, current_user.id) for p in posts],
"total": total, "page": page, "page_size": page_size,
}
def _enrich_post(post: Post, db: Session, current_user_id: int = None) -> PostResponse:
"""填充帖子额外字段"""
author = db.query(User).filter(User.id == post.user_id).first()
result = PostResponse.model_validate(post)
result.author_name = author.username if author else "未知用户"
if current_user_id:
result.is_liked = db.query(Like).filter(
Like.post_id == post.id, Like.user_id == current_user_id
).first() is not None
result.is_collected = db.query(Collect).filter(
Collect.post_id == post.id, Collect.user_id == current_user_id
).first() is not None
return result
@router.get("", response_model=PostListResponse)
def get_posts(
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100),
category: Optional[str] = None,
tag: Optional[str] = None,
user_id: Optional[int] = None,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""获取帖子列表"""
query = db.query(Post).filter(
or_(Post.is_public == True, Post.user_id == current_user.id),
Post.is_draft == False,
)
if category:
query = query.filter(Post.category == category)
if tag:
query = query.filter(Post.tags.contains(tag))
if user_id:
query = query.filter(Post.user_id == user_id)
total = query.count()
posts = query.order_by(Post.created_at.desc()).offset((page - 1) * page_size).limit(page_size).all()
return PostListResponse(
items=[_enrich_post(p, db, current_user.id) for p in posts],
total=total,
page=page,
page_size=page_size,
)
@router.post("", response_model=PostResponse)
def create_post(
data: PostCreate,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""发布经验帖"""
post = Post(
user_id=current_user.id,
title=data.title,
content=data.content,
category=data.category,
tags=json.dumps(data.tags, ensure_ascii=False),
is_public=data.is_public,
is_draft=data.is_draft,
)
db.add(post)
db.commit()
db.refresh(post)
return _enrich_post(post, db, current_user.id)
@router.get("/{post_id}", response_model=PostResponse)
def get_post(
post_id: int,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""获取帖子详情"""
post = db.query(Post).filter(Post.id == post_id).first()
if not post:
raise HTTPException(status_code=404, detail="帖子不存在")
if not post.is_public and post.user_id != current_user.id:
raise HTTPException(status_code=403, detail="无权访问")
# 增加浏览量
post.view_count += 1
db.commit()
db.refresh(post)
return _enrich_post(post, db, current_user.id)
@router.put("/{post_id}", response_model=PostResponse)
def update_post(
post_id: int,
data: PostUpdate,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""编辑帖子"""
# 管理员可以编辑任意帖子,普通用户只能编辑自己的
if current_user.is_admin:
post = db.query(Post).filter(Post.id == post_id).first()
else:
post = db.query(Post).filter(Post.id == post_id, Post.user_id == current_user.id).first()
if not post:
raise HTTPException(status_code=404, detail="帖子不存在或无权编辑")
if data.title is not None:
post.title = data.title
if data.content is not None:
post.content = data.content
if data.category is not None:
post.category = data.category
if data.tags is not None:
post.tags = json.dumps(data.tags, ensure_ascii=False)
if data.is_public is not None:
post.is_public = data.is_public
if data.is_draft is not None:
post.is_draft = data.is_draft
db.commit()
db.refresh(post)
return _enrich_post(post, db, current_user.id)
@router.delete("/{post_id}")
def delete_post(
post_id: int,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""删除帖子"""
post = db.query(Post).filter(Post.id == post_id, Post.user_id == current_user.id).first()
if not post:
raise HTTPException(status_code=404, detail="帖子不存在或无权删除")
db.query(Comment).filter(Comment.post_id == post_id).delete()
db.query(Like).filter(Like.post_id == post_id).delete()
db.query(Collect).filter(Collect.post_id == post_id).delete()
db.query(Attachment).filter(Attachment.post_id == post_id).delete()
db.delete(post)
db.commit()
return {"message": "删除成功"}
@router.post("/{post_id}/like")
def toggle_like(
post_id: int,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""点赞/取消点赞"""
post = db.query(Post).filter(Post.id == post_id).first()
if not post:
raise HTTPException(status_code=404, detail="帖子不存在")
existing = db.query(Like).filter(
Like.post_id == post_id, Like.user_id == current_user.id
).first()
if existing:
db.delete(existing)
post.like_count = max(0, post.like_count - 1)
db.commit()
return {"liked": False, "like_count": post.like_count}
else:
db.add(Like(post_id=post_id, user_id=current_user.id))
post.like_count += 1
# 通知帖子作者
if post.user_id != current_user.id:
db.add(Notification(
user_id=post.user_id, type="like",
content=f"{current_user.username} 赞了你的文章「{post.title[:30]}",
from_user_id=current_user.id, related_id=post_id,
))
db.commit()
return {"liked": True, "like_count": post.like_count}
@router.post("/{post_id}/collect")
def toggle_collect(
post_id: int,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""收藏/取消收藏"""
post = db.query(Post).filter(Post.id == post_id).first()
if not post:
raise HTTPException(status_code=404, detail="帖子不存在")
existing = db.query(Collect).filter(
Collect.post_id == post_id, Collect.user_id == current_user.id
).first()
if existing:
db.delete(existing)
post.collect_count = max(0, post.collect_count - 1)
db.commit()
return {"collected": False, "collect_count": post.collect_count}
else:
db.add(Collect(post_id=post_id, user_id=current_user.id))
post.collect_count += 1
db.commit()
return {"collected": True, "collect_count": post.collect_count}
@router.get("/{post_id}/comments", response_model=List[CommentResponse])
def get_comments(
post_id: int,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""获取评论列表"""
comments = (
db.query(Comment)
.filter(Comment.post_id == post_id)
.order_by(Comment.created_at.asc())
.all()
)
results = []
for c in comments:
author = db.query(User).filter(User.id == c.user_id).first()
r = CommentResponse.model_validate(c)
r.author_name = author.username if author else "未知用户"
results.append(r)
return results
@router.post("/{post_id}/comments", response_model=CommentResponse)
def create_comment(
post_id: int,
data: CommentCreate,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""发表评论"""
post = db.query(Post).filter(Post.id == post_id).first()
if not post:
raise HTTPException(status_code=404, detail="帖子不存在")
comment = Comment(
post_id=post_id,
user_id=current_user.id,
content=data.content,
)
db.add(comment)
post.comment_count += 1
# 通知帖子作者
if post.user_id != current_user.id:
db.add(Notification(
user_id=post.user_id, type="comment",
content=f"{current_user.username} 评论了你的文章「{post.title[:30]}",
from_user_id=current_user.id, related_id=post_id,
))
db.commit()
db.refresh(comment)
result = CommentResponse.model_validate(comment)
result.author_name = current_user.username
return result
@router.get("/{post_id}/attachments")
def get_attachments(
post_id: int,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""获取帖子的附件列表"""
attachments = (
db.query(Attachment)
.filter(Attachment.post_id == post_id)
.order_by(Attachment.created_at.asc())
.all()
)
return [
{
"id": a.id,
"filename": a.filename,
"url": a.url,
"file_size": a.file_size,
"file_type": a.file_type,
"created_at": a.created_at,
}
for a in attachments
]
@router.delete("/{post_id}/attachments/{attachment_id}")
def delete_attachment(
post_id: int,
attachment_id: int,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""删除附件(仅作者)"""
attachment = db.query(Attachment).filter(
Attachment.id == attachment_id,
Attachment.post_id == post_id,
Attachment.user_id == current_user.id,
).first()
if not attachment:
raise HTTPException(status_code=404, detail="附件不存在或无权删除")
db.delete(attachment)
db.commit()
return {"message": "删除成功"}

410
backend/routers/projects.py Normal file
View File

@@ -0,0 +1,410 @@
"""开源项目路由"""
import httpx
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session
from sqlalchemy import func as sa_func, or_
from pydantic import BaseModel
from typing import Optional, List
from database import get_db
from models.user import User
from models.project import Project
from models.like import ProjectCollect
from routers.auth import get_current_user, get_admin_user
router = APIRouter()
GITHUB_API = "https://api.github.com"
# ========== Schemas ==========
class ProjectCreate(BaseModel):
name: str
description: str = ""
url: str
homepage: str = ""
icon: str = ""
language: str = ""
category: str = ""
stars: int = 0
forks: int = 0
class ProjectUpdate(BaseModel):
name: Optional[str] = None
description: Optional[str] = None
url: Optional[str] = None
homepage: Optional[str] = None
icon: Optional[str] = None
language: Optional[str] = None
category: Optional[str] = None
stars: Optional[int] = None
forks: Optional[int] = None
sort_order: Optional[int] = None
is_active: Optional[bool] = None
def _project_to_dict(p: Project, is_collected: bool = False) -> dict:
return {
"id": p.id,
"name": p.name,
"description": p.description or "",
"url": p.url,
"homepage": p.homepage or "",
"icon": p.icon or "",
"language": p.language or "",
"category": p.category or "",
"stars": p.stars or 0,
"forks": p.forks or 0,
"collect_count": getattr(p, 'collect_count', 0) or 0,
"is_collected": is_collected,
"sort_order": p.sort_order,
"is_active": p.is_active,
"created_at": p.created_at.isoformat() if p.created_at else None,
"updated_at": p.updated_at.isoformat() if p.updated_at else None,
}
def _with_collect_status(items: list, user_id: int, db: Session) -> list:
"""批量查询用户是否已收藏"""
if not items:
return []
project_ids = [p.id for p in items]
collected_ids = set(
r[0] for r in db.query(ProjectCollect.project_id)
.filter(ProjectCollect.project_id.in_(project_ids), ProjectCollect.user_id == user_id)
.all()
)
return [_project_to_dict(p, p.id in collected_ids) for p in items]
# ========== 管理员接口 ==========
@router.get("/admin/list")
def admin_list_projects(
page: int = Query(1, ge=1),
size: int = Query(20, ge=1, le=100),
category: Optional[str] = None,
keyword: Optional[str] = None,
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user),
):
"""获取所有项目(含禁用),支持分页"""
query = db.query(Project)
if category:
query = query.filter(Project.category == category)
if keyword:
kw = f"%{keyword}%"
query = query.filter(or_(Project.name.like(kw), Project.description.like(kw)))
total = query.count()
items = query.order_by(Project.sort_order, Project.id.desc()).offset((page - 1) * size).limit(size).all()
return {
"total": total,
"items": [_project_to_dict(p) for p in items],
}
@router.post("/admin")
def admin_create_project(
data: ProjectCreate,
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user),
):
"""新增项目"""
max_order = db.query(sa_func.max(Project.sort_order)).scalar() or 0
proj = Project(
name=data.name,
description=data.description,
url=data.url,
homepage=data.homepage,
icon=data.icon,
language=data.language,
category=data.category,
stars=data.stars,
forks=data.forks,
sort_order=max_order + 1,
)
db.add(proj)
db.commit()
db.refresh(proj)
return _project_to_dict(proj)
@router.put("/admin/{project_id}")
def admin_update_project(
project_id: int,
data: ProjectUpdate,
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user),
):
"""编辑项目"""
proj = db.query(Project).filter(Project.id == project_id).first()
if not proj:
raise HTTPException(status_code=404, detail="项目不存在")
for field in ["name", "description", "url", "homepage", "icon", "language", "category", "stars", "forks", "sort_order", "is_active"]:
val = getattr(data, field)
if val is not None:
setattr(proj, field, val)
db.commit()
db.refresh(proj)
return _project_to_dict(proj)
@router.delete("/admin/{project_id}")
def admin_delete_project(
project_id: int,
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user),
):
"""删除项目"""
proj = db.query(Project).filter(Project.id == project_id).first()
if not proj:
raise HTTPException(status_code=404, detail="项目不存在")
db.delete(proj)
db.commit()
return {"message": "删除成功"}
# ========== 公开接口 ==========
@router.get("/hot")
def get_hot_projects(
page: int = Query(1, ge=1),
size: int = Query(12, ge=1, le=50),
category: Optional[str] = None,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""热门项目(按 stars 降序)"""
query = db.query(Project).filter(Project.is_active == True)
if category:
query = query.filter(Project.category == category)
total = query.count()
items = query.order_by(Project.stars.desc(), Project.sort_order, Project.id.desc()).offset((page - 1) * size).limit(size).all()
return {"total": total, "items": _with_collect_status(items, current_user.id, db)}
@router.get("/latest")
def get_latest_projects(
page: int = Query(1, ge=1),
size: int = Query(12, ge=1, le=50),
category: Optional[str] = None,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""最新项目(按创建时间降序)"""
query = db.query(Project).filter(Project.is_active == True)
if category:
query = query.filter(Project.category == category)
total = query.count()
items = query.order_by(Project.created_at.desc(), Project.id.desc()).offset((page - 1) * size).limit(size).all()
return {"total": total, "items": _with_collect_status(items, current_user.id, db)}
@router.get("/search")
def search_projects(
q: str = Query("", min_length=0),
page: int = Query(1, ge=1),
size: int = Query(12, ge=1, le=50),
category: Optional[str] = None,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""搜索项目"""
query = db.query(Project).filter(Project.is_active == True)
if q.strip():
kw = f"%{q.strip()}%"
query = query.filter(or_(Project.name.like(kw), Project.description.like(kw)))
if category:
query = query.filter(Project.category == category)
total = query.count()
items = query.order_by(Project.stars.desc(), Project.id.desc()).offset((page - 1) * size).limit(size).all()
return {"total": total, "items": _with_collect_status(items, current_user.id, db)}
@router.get("/categories")
def get_project_categories(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""获取所有有项目的分类"""
rows = (
db.query(Project.category, sa_func.count(Project.id))
.filter(Project.is_active == True, Project.category != "")
.group_by(Project.category)
.order_by(sa_func.count(Project.id).desc())
.all()
)
return [{"name": r[0], "count": r[1]} for r in rows]
# ========== 收藏接口 ==========
@router.get("/my-collects")
def get_my_collects(
page: int = Query(1, ge=1),
size: int = Query(12, ge=1, le=50),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""获取当前用户收藏的项目"""
subq = db.query(ProjectCollect.project_id).filter(ProjectCollect.user_id == current_user.id).subquery()
query = db.query(Project).filter(Project.id.in_(subq), Project.is_active == True)
total = query.count()
items = query.order_by(Project.id.desc()).offset((page - 1) * size).limit(size).all()
return {"total": total, "items": [_project_to_dict(p, True) for p in items]}
@router.post("/{project_id}/collect")
def toggle_project_collect(
project_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""收藏/取消收藏项目"""
proj = db.query(Project).filter(Project.id == project_id).first()
if not proj:
raise HTTPException(status_code=404, detail="项目不存在")
existing = db.query(ProjectCollect).filter(
ProjectCollect.project_id == project_id, ProjectCollect.user_id == current_user.id
).first()
if existing:
db.delete(existing)
proj.collect_count = max(0, (proj.collect_count or 0) - 1)
db.commit()
return {"collected": False, "collect_count": proj.collect_count}
else:
db.add(ProjectCollect(project_id=project_id, user_id=current_user.id))
proj.collect_count = (proj.collect_count or 0) + 1
db.commit()
return {"collected": True, "collect_count": proj.collect_count}
# ========== GitHub 搜索(公共 + 管理员通用) ==========
async def _github_search_impl(q: str, sort: str, page: int, per_page: int):
"""GitHub 搜索核心实现"""
url = f"{GITHUB_API}/search/repositories"
params = {"q": q, "sort": sort, "order": "desc", "page": page, "per_page": per_page}
headers = {"Accept": "application/vnd.github+json"}
try:
async with httpx.AsyncClient(timeout=15) as client:
resp = await client.get(url, params=params, headers=headers)
if resp.status_code != 200:
raise HTTPException(status_code=502, detail=f"GitHub API 返回 {resp.status_code}")
data = resp.json()
except httpx.TimeoutException:
raise HTTPException(status_code=504, detail="GitHub API 请求超时")
except httpx.RequestError as e:
raise HTTPException(status_code=502, detail=f"网络请求失败: {str(e)}")
items = []
for repo in data.get("items", []):
items.append({
"github_id": repo["id"],
"name": repo.get("name", ""),
"full_name": repo.get("full_name", ""),
"description": repo.get("description") or "",
"url": repo.get("html_url", ""),
"homepage": repo.get("homepage") or "",
"icon": repo.get("owner", {}).get("avatar_url", ""),
"language": repo.get("language") or "",
"stars": repo.get("stargazers_count", 0),
"forks": repo.get("forks_count", 0),
"topics": repo.get("topics", []),
"created_at": repo.get("created_at", ""),
"updated_at": repo.get("updated_at", ""),
})
return {"total": data.get("total_count", 0), "items": items}
@router.get("/github-search")
async def public_github_search(
q: str = Query(..., min_length=1),
sort: str = Query("stars"),
page: int = Query(1, ge=1),
per_page: int = Query(12, ge=1, le=30),
user: User = Depends(get_current_user),
):
"""公开 GitHub 搜索(登录用户可用)"""
return await _github_search_impl(q, sort, page, per_page)
# ========== GitHub 导入接口(管理员) ==========
@router.get("/admin/github-search")
async def github_search(
q: str = Query(..., min_length=1),
sort: str = Query("stars"),
page: int = Query(1, ge=1),
per_page: int = Query(12, ge=1, le=30),
admin: User = Depends(get_admin_user),
):
"""管理员 GitHub 搜索"""
return await _github_search_impl(q, sort, page, per_page)
class GitHubImportItem(BaseModel):
name: str
description: str = ""
url: str
homepage: str = ""
icon: str = ""
language: str = ""
category: str = ""
stars: int = 0
forks: int = 0
class GitHubImportRequest(BaseModel):
items: List[GitHubImportItem]
@router.post("/admin/github-import")
def github_import(
data: GitHubImportRequest,
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user),
):
"""批量导入 GitHub 项目"""
imported = 0
skipped = 0
max_order = db.query(sa_func.max(Project.sort_order)).scalar() or 0
for item in data.items:
existing = db.query(Project).filter(Project.url == item.url).first()
if existing:
skipped += 1
continue
max_order += 1
proj = Project(
name=item.name,
description=item.description,
url=item.url,
homepage=item.homepage,
icon=item.icon,
language=item.language,
category=item.category,
stars=item.stars,
forks=item.forks,
sort_order=max_order,
)
db.add(proj)
imported += 1
db.commit()
return {"imported": imported, "skipped": skipped}
# ========== 项目详情(通配路由放最后) ==========
@router.get("/{project_id}")
def get_project_detail(
project_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""项目详情"""
proj = db.query(Project).filter(Project.id == project_id, Project.is_active == True).first()
if not proj:
raise HTTPException(status_code=404, detail="项目不存在")
return _project_to_dict(proj)

View File

@@ -0,0 +1,313 @@
"""需求理解助手路由"""
import json
import os
import uuid
from typing import List
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from database import get_db
from models.user import User
from models.conversation import Conversation, Message
from schemas.conversation import (
RequirementAnalyzeRequest, ConversationResponse,
ConversationDetail, MessageResponse,
)
from routers.auth import get_current_user
from services.ai_service import ai_service
from config import UPLOAD_DIR
router = APIRouter()
REQUIREMENT_SYSTEM_PROMPT = """# 角色定义
你同时具备两个身份:
1. **资深产品经理10年+**擅长从模糊信息中提炼需求本质精通用户故事、MECE拆解、优先级排序、验收标准制定
2. **高级全栈程序员10年+**:做过大量项目,对功能的复杂度、开发工作量有精准直觉,能一眼看出哪些需求是"看似简单实则巨坑"
你服务的对象是**程序员**,他们会把甲方发来的原始内容(口语化描述、语音转文字、截图、聊天记录、需求文档等)发给你。你需要站在「既懂业务又懂技术」的双重视角,将其转化为**清晰明确、可直接进入开发的结构化需求**。
> ⚠️ 本助手专注于**需求理解与分析**。技术选型、数据库设计、API设计、架构方案等请移步「架构选型AI助手」。
# 核心理念
- **需求层面**:甲方说的不一定是他真正想要的,你要透过表面描述挖掘真实诉求
- **程序员直觉**:每个功能你都要心里过一遍复杂度,标注哪些"看起来简单但实际很复杂"
- **落地导向**:不出空中楼阁式的分析,每条功能都要能对应到具体的开发任务
# 分析框架
收到用户输入后,按以下框架进行系统性分析:
## 第一步:需求还原(产品经理视角)
- 理解甲方的**核心意图**:到底想解决什么问题?服务什么业务场景?
- 识别**目标用户**是谁,使用频率、核心使用路径是什么
- 区分"真实需求""表面描述",找出甲方没说但一定需要的**隐性需求**
- 判断产品定位:工具型/平台型/内容型ToB/ToC
## 第二步:功能拆解(遵循 MECE 原则)
将需求拆解为相互独立、完全穷尽的功能模块,每个功能需包含:
- 功能名称 + 具体说明
- 优先级P0 核心必做 / P1 重要 / P2 锦上添花)
- 涉及的用户角色
- 复杂度预判(简单/中等/复杂)+ 复杂原因说明
## 第三步:用户故事与验收标准
将核心功能转写为标准用户故事:
> 作为【角色】,我希望【功能】,以便【价值/目的】
为 P0 功能补充验收标准AC采用 Given-When-Then 格式:
> 假设【前置条件】,当【用户操作】,那么【系统响应】
## 第四步:复杂度预警(程序员视角)
基于你做过大量项目的经验,标注:
- **隐藏复杂度**:哪些功能看似简单但实现起来有坑(如"支持实时聊天"看似一句话实际涉及WebSocket、消息队列、已读回执等
- **高风险功能**:涉及支付、权限、数据一致性等需要特别慎重的功能
- **工期杀手**:容易严重超出预期工时的功能,提前预警
- **工期粗估**按模块给出大致工时范围x-x天帮助程序员评估排期
## 第五步:边界与待确认项
- 需求中**含糊不清**需要甲方确认的关键问题
- 容易遗漏的**边缘场景**(空状态、异常流、权限边界、数据量极值、并发操作)
- 甲方可能还没想到但**一定会追加**的需求(基于经验预判)
# 输出规范
严格使用以下 Markdown 结构输出:
---
## 📋 需求概述
> 用2-3句话概括这是什么产品/功能,解决谁的什么问题,核心价值是什么。
> 产品定位:【工具型/平台型/内容型】 | 【ToB/ToC】
## 👥 用户角色
| 角色 | 说明 | 关键诉求 | 使用频率 |
|------|------|---------|---------|
## 🧩 功能清单
| 优先级 | 模块 | 功能项 | 功能说明 | 涉及角色 | 复杂度 |
|--------|------|--------|---------|---------|--------|
| P0 | xxx | xxx | xxx | xxx | 🔴复杂 - 原因 |
| P0 | xxx | xxx | xxx | xxx | 🟢简单 |
| P1 | xxx | xxx | xxx | xxx | 🟡中等 |
## 📖 核心用户故事与验收标准
### US-1: 【用户故事标题】
- **故事**:作为【角色】,我希望【功能】,以便【价值】
- **验收标准**
- ✅ 假设【条件】,当【操作】,那么【预期结果】
- ✅ 假设【条件】,当【异常操作】,那么【兜底处理】
## ⚡ 复杂度预警(程序员必看)
### 🔴 隐藏复杂度
1. **【功能名】**看似xxx实际需要xxx建议xxx
### ⏱️ 工期粗估
| 模块 | 工时范围 | 说明 |
|------|---------|------|
| 合计 | x-x天 | 基于1名全栈开发者含开发+自测 |
### 🔮 甲方大概率会追加的需求
1. 【需求名】— 理由基于经验做了xxx之后甲方通常会要求xxx
## ❓ 待确认问题
> 以下问题会直接影响开发方案,建议优先与甲方确认:
1. **【问题】**:【为什么需要确认】→ 不确认的影响【xxx】
## 💡 风险提示
- 【业务风险、边缘场景、容易遗漏的点等】
---
> 💡 **下一步建议**需求确认后可将功能清单发送至「架构选型AI助手」获取技术选型、数据库设计、API接口设计等技术方案。
# 交互原则
1. **首次分析要全面**:即使信息不完整,也要基于已有信息给出最完整的分析,用待确认问题标注不确定的部分
2. **复杂度标注要诚实**:程序员最怕"这个很简单",要如实标注复杂度和潜在的坑
3. **追问要有价值**:只追问影响开发方案的关键问题,不问显而易见的
4. **语言贴近程序员**:用开发者能直接理解的术语,避免纯商业话术
5. **持续迭代**:用户补充信息后,在之前分析的基础上更新完善,而非重头开始
6. **务实不空谈**:每条建议都基于实际经验,不说"建议优化用户体验"这类空话
7. **工时要诚实**:给出合理区间,标注不确定因素"""
@router.post("/upload-image")
async def upload_image(
file: UploadFile = File(...),
current_user: User = Depends(get_current_user),
):
"""上传图片"""
# 验证文件类型
allowed_types = ["image/jpeg", "image/png", "image/gif", "image/webp"]
if file.content_type not in allowed_types:
raise HTTPException(status_code=400, detail="不支持的文件类型")
# 生成唯一文件名
ext = file.filename.split(".")[-1] if "." in file.filename else "png"
filename = f"{uuid.uuid4().hex}.{ext}"
filepath = os.path.join(UPLOAD_DIR, filename)
# 保存文件
content = await file.read()
with open(filepath, "wb") as f:
f.write(content)
return {"url": f"/uploads/{filename}"}
@router.get("/conversations", response_model=List[ConversationResponse])
def get_conversations(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""获取需求对话列表"""
conversations = (
db.query(Conversation)
.filter(Conversation.user_id == current_user.id, Conversation.type == "requirement")
.order_by(Conversation.updated_at.desc())
.all()
)
return [ConversationResponse.model_validate(c) for c in conversations]
@router.get("/conversations/{conversation_id}", response_model=ConversationDetail)
def get_conversation_detail(
conversation_id: int,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""获取对话详情"""
conv = db.query(Conversation).filter(
Conversation.id == conversation_id,
Conversation.user_id == current_user.id,
).first()
if not conv:
raise HTTPException(status_code=404, detail="对话不存在")
messages = (
db.query(Message)
.filter(Message.conversation_id == conversation_id)
.order_by(Message.created_at.asc())
.all()
)
result = ConversationDetail.model_validate(conv)
result.messages = [MessageResponse.model_validate(m) for m in messages]
return result
@router.post("/analyze")
async def analyze_requirement(
request: RequirementAnalyzeRequest,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""分析需求 - 流式输出"""
# 创建或获取对话
if request.conversation_id:
conv = db.query(Conversation).filter(
Conversation.id == request.conversation_id,
Conversation.user_id == current_user.id,
).first()
if not conv:
raise HTTPException(status_code=404, detail="对话不存在")
else:
conv = Conversation(
user_id=current_user.id,
title=request.content[:50] if request.content else "新需求分析",
type="requirement",
)
db.add(conv)
db.commit()
db.refresh(conv)
# 保存用户消息
user_msg = Message(
conversation_id=conv.id,
role="user",
content=request.content,
image_urls=json.dumps(request.image_urls) if request.image_urls else "",
)
db.add(user_msg)
db.commit()
# 构建历史消息
history_msgs = (
db.query(Message)
.filter(Message.conversation_id == conv.id)
.order_by(Message.created_at.asc())
.all()
)
messages = []
for msg in history_msgs:
if msg.role == "user":
content = msg.content
# 如果有图片,添加图片描述提示
if msg.image_urls:
try:
urls = json.loads(msg.image_urls)
if urls:
content += f"\n\n[用户上传了{len(urls)}张图片]"
except json.JSONDecodeError:
pass
messages.append({"role": "user", "content": content})
else:
messages.append({"role": "assistant", "content": msg.content})
# 确定任务类型
task_type = "multimodal" if request.image_urls else "reasoning"
# 流式调用AI
async def generate():
full_response = ""
try:
result = await ai_service.chat(
task_type=task_type,
messages=messages,
system_prompt=REQUIREMENT_SYSTEM_PROMPT,
stream=True,
model_config_id=request.model_config_id,
)
if isinstance(result, str):
# 非流式返回
full_response = result
yield f"data: {json.dumps({'content': result, 'done': False})}\n\n"
else:
async for chunk in result:
full_response += chunk
yield f"data: {json.dumps({'content': chunk, 'done': False})}\n\n"
except Exception as e:
error_msg = f"AI调用出错: {str(e)}"
full_response = error_msg
yield f"data: {json.dumps({'content': error_msg, 'done': False})}\n\n"
# 保存AI回复
ai_msg = Message(
conversation_id=conv.id,
role="assistant",
content=full_response,
)
db.add(ai_msg)
db.commit()
yield f"data: {json.dumps({'content': '', 'done': True, 'conversation_id': conv.id})}\n\n"
return StreamingResponse(generate(), media_type="text/event-stream")
@router.delete("/conversations/{conversation_id}")
def delete_conversation(
conversation_id: int,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""删除对话"""
conv = db.query(Conversation).filter(
Conversation.id == conversation_id,
Conversation.user_id == current_user.id,
).first()
if not conv:
raise HTTPException(status_code=404, detail="对话不存在")
db.query(Message).filter(Message.conversation_id == conversation_id).delete()
db.delete(conv)
db.commit()
return {"message": "删除成功"}

47
backend/routers/search.py Normal file
View File

@@ -0,0 +1,47 @@
"""搜索路由"""
from typing import Optional
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from sqlalchemy import or_
from database import get_db
from models.user import User
from models.post import Post
from schemas.post import PostResponse, PostListResponse
from routers.auth import get_current_user
from routers.posts import _enrich_post
router = APIRouter()
@router.get("", response_model=PostListResponse)
def search_posts(
q: str = Query(..., min_length=1, description="搜索关键词"),
category: Optional[str] = None,
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""全文搜索帖子"""
query = db.query(Post).filter(
or_(Post.is_public == True, Post.user_id == current_user.id),
or_(
Post.title.contains(q),
Post.content.contains(q),
Post.tags.contains(q),
),
)
if category:
query = query.filter(Post.category == category)
total = query.count()
posts = query.order_by(Post.created_at.desc()).offset((page - 1) * page_size).limit(page_size).all()
return PostListResponse(
items=[_enrich_post(p, db, current_user.id) for p in posts],
total=total,
page=page,
page_size=page_size,
)

View File

@@ -0,0 +1,526 @@
"""共享API Hub路由"""
from fastapi import APIRouter, Depends, HTTPException, Header, Request
from sqlalchemy.orm import Session
from sqlalchemy import func as sa_func
from pydantic import BaseModel
from typing import Optional
from datetime import datetime, timedelta
import time
import hashlib
from database import get_db
from config import SECRET_KEY
from models.user import User
from models.system_config import SystemConfig
from models.shared_api import SharedApiCategory, SharedApi, SharedApiLog
from routers.auth import get_current_user, get_admin_user
router = APIRouter()
# ========== 加密工具 ==========
_fernet = None
def _get_fernet():
global _fernet
if _fernet is None:
from cryptography.fernet import Fernet
import base64
# 从SECRET_KEY派生一个Fernet兼容的key
key = hashlib.sha256(SECRET_KEY.encode()).digest()
_fernet = Fernet(base64.urlsafe_b64encode(key))
return _fernet
def encrypt_key(plain: str) -> str:
if not plain:
return ""
return _get_fernet().encrypt(plain.encode()).decode()
def decrypt_key(encrypted: str) -> str:
if not encrypted:
return ""
try:
return _get_fernet().decrypt(encrypted.encode()).decode()
except Exception:
return ""
def mask_key(encrypted: str) -> str:
"""脱敏显示"""
plain = decrypt_key(encrypted)
if not plain:
return ""
if len(plain) <= 8:
return plain[:2] + "***"
return plain[:4] + "****" + plain[-4:]
# ========== Hub访问密码机制 ==========
def _get_hub_password(db: Session) -> str:
cfg = db.query(SystemConfig).filter(SystemConfig.key == "api_hub_password").first()
return cfg.value if cfg else ""
def _hub_password_version(db: Session) -> str:
"""返回密码哈希前8位作为版本标识密码变更后旧token自动失效"""
pwd = _get_hub_password(db)
return pwd[:8] if pwd else "none"
def _hash_password(pwd: str) -> str:
return hashlib.sha256(pwd.encode()).hexdigest()
def _create_hub_token(user_id: int, pwd_ver: str = "none") -> str:
"""创建Hub访问令牌简单签名2小时有效"""
from jose import jwt
exp = datetime.utcnow() + timedelta(hours=2)
return jwt.encode({"sub": str(user_id), "hub": True, "pv": pwd_ver, "exp": exp}, SECRET_KEY, algorithm="HS256")
def verify_hub_access(
x_hub_token: Optional[str] = Header(None),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""验证用户登录 + Hub访问令牌"""
if not x_hub_token:
raise HTTPException(status_code=403, detail="需要API Hub访问权限请先验证密码")
from jose import jwt, JWTError
try:
payload = jwt.decode(x_hub_token, SECRET_KEY, algorithms=["HS256"])
if not payload.get("hub"):
raise HTTPException(status_code=403, detail="无效的Hub令牌")
# 检查密码版本是否匹配
token_pv = payload.get("pv", "")
current_pv = _hub_password_version(db)
if token_pv != current_pv:
raise HTTPException(status_code=403, detail="密码已变更,请重新验证")
except JWTError:
raise HTTPException(status_code=403, detail="Hub令牌已过期请重新验证密码")
return current_user
# ========== Schemas ==========
class HubAuthRequest(BaseModel):
password: str
class CategoryCreate(BaseModel):
name: str
icon: str = ""
class CategoryUpdate(BaseModel):
name: Optional[str] = None
icon: Optional[str] = None
sort_order: Optional[int] = None
is_active: Optional[bool] = None
class ApiCreate(BaseModel):
category_id: Optional[int] = None
name: str
description: str = ""
base_url: str = ""
doc_url: str = ""
auth_type: str = "none"
api_key: str = "" # 明文传入,后端加密存储
api_key_header: str = "Authorization"
health_check_url: str = ""
tags: str = ""
class ApiUpdate(BaseModel):
category_id: Optional[int] = None
name: Optional[str] = None
description: Optional[str] = None
base_url: Optional[str] = None
doc_url: Optional[str] = None
auth_type: Optional[str] = None
api_key: Optional[str] = None
api_key_header: Optional[str] = None
health_check_url: Optional[str] = None
tags: Optional[str] = None
is_active: Optional[bool] = None
class ApiTestRequest(BaseModel):
method: str = "GET"
path: str = ""
body: str = ""
headers: dict = {}
# ========== 密码认证接口 ==========
@router.post("/auth")
def hub_auth(
data: HubAuthRequest,
db: Session = Depends(get_db),
user: User = Depends(get_current_user),
):
"""验证Hub访问密码"""
stored = _get_hub_password(db)
if not stored:
raise HTTPException(status_code=400, detail="管理员尚未设置访问密码")
if _hash_password(data.password) != stored:
raise HTTPException(status_code=403, detail="密码错误")
token = _create_hub_token(user.id, _hub_password_version(db))
return {"hub_token": token, "expires_in": 7200}
@router.get("/check-password")
def check_password_set(
db: Session = Depends(get_db),
user: User = Depends(get_current_user),
):
"""检查是否已设置访问密码"""
stored = _get_hub_password(db)
return {"has_password": bool(stored)}
@router.put("/admin/password")
def set_hub_password(
data: HubAuthRequest,
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user),
):
"""管理员设置Hub访问密码"""
if len(data.password) < 4:
raise HTTPException(status_code=400, detail="密码至少4位")
hashed = _hash_password(data.password)
cfg = db.query(SystemConfig).filter(SystemConfig.key == "api_hub_password").first()
if cfg:
cfg.value = hashed
else:
cfg = SystemConfig(key="api_hub_password", value=hashed, description="API Hub访问密码")
db.add(cfg)
db.commit()
return {"message": "密码设置成功"}
@router.get("/admin/password-status")
def get_password_status(
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user),
):
"""管理员查看密码是否已设置"""
stored = _get_hub_password(db)
return {"has_password": bool(stored)}
# ========== 分类管理 ==========
@router.get("/categories")
def list_categories(
db: Session = Depends(get_db),
user: User = Depends(verify_hub_access),
):
cats = db.query(SharedApiCategory).order_by(SharedApiCategory.sort_order, SharedApiCategory.id).all()
return [
{"id": c.id, "name": c.name, "icon": c.icon, "sort_order": c.sort_order, "is_active": c.is_active,
"api_count": db.query(sa_func.count(SharedApi.id)).filter(SharedApi.category_id == c.id).scalar() or 0}
for c in cats
]
@router.post("/categories")
def create_category(
data: CategoryCreate,
db: Session = Depends(get_db),
user: User = Depends(verify_hub_access),
):
existing = db.query(SharedApiCategory).filter(SharedApiCategory.name == data.name).first()
if existing:
raise HTTPException(status_code=400, detail="分类名称已存在")
max_order = db.query(sa_func.max(SharedApiCategory.sort_order)).scalar() or 0
cat = SharedApiCategory(name=data.name, icon=data.icon, sort_order=max_order + 1)
db.add(cat)
db.commit()
db.refresh(cat)
return {"id": cat.id, "name": cat.name, "icon": cat.icon}
@router.put("/categories/{cat_id}")
def update_category(
cat_id: int,
data: CategoryUpdate,
db: Session = Depends(get_db),
user: User = Depends(verify_hub_access),
):
cat = db.query(SharedApiCategory).filter(SharedApiCategory.id == cat_id).first()
if not cat:
raise HTTPException(status_code=404, detail="分类不存在")
if data.name is not None:
cat.name = data.name
if data.icon is not None:
cat.icon = data.icon
if data.sort_order is not None:
cat.sort_order = data.sort_order
if data.is_active is not None:
cat.is_active = data.is_active
db.commit()
return {"message": "更新成功"}
@router.delete("/categories/{cat_id}")
def delete_category(
cat_id: int,
db: Session = Depends(get_db),
user: User = Depends(verify_hub_access),
):
cat = db.query(SharedApiCategory).filter(SharedApiCategory.id == cat_id).first()
if not cat:
raise HTTPException(status_code=404, detail="分类不存在")
# 将该分类下的API设为未分类
db.query(SharedApi).filter(SharedApi.category_id == cat_id).update({SharedApi.category_id: None})
db.delete(cat)
db.commit()
return {"message": "删除成功"}
# ========== API CRUD ==========
def _api_to_dict(api, db=None):
d = {
"id": api.id, "category_id": api.category_id,
"name": api.name, "description": api.description,
"base_url": api.base_url, "doc_url": api.doc_url,
"auth_type": api.auth_type,
"api_key_masked": mask_key(api.api_key_encrypted),
"api_key_plain": decrypt_key(api.api_key_encrypted) if api.api_key_encrypted else "",
"has_api_key": bool(api.api_key_encrypted),
"api_key_header": api.api_key_header,
"health_check_url": api.health_check_url,
"last_check_time": api.last_check_time.isoformat() if api.last_check_time else None,
"last_check_status": api.last_check_status,
"added_by": api.added_by, "tags": api.tags,
"call_count": api.call_count, "is_active": api.is_active,
"created_at": api.created_at.isoformat() if api.created_at else None,
"updated_at": api.updated_at.isoformat() if api.updated_at else None,
}
return d
@router.get("/list")
def list_apis(
keyword: Optional[str] = None,
category_id: Optional[int] = None,
db: Session = Depends(get_db),
user: User = Depends(verify_hub_access),
):
query = db.query(SharedApi).filter(SharedApi.is_active == True)
if keyword:
kw = f"%{keyword}%"
query = query.filter(
(SharedApi.name.like(kw)) | (SharedApi.description.like(kw)) | (SharedApi.tags.like(kw))
)
if category_id is not None:
query = query.filter(SharedApi.category_id == category_id)
apis = query.order_by(SharedApi.call_count.desc(), SharedApi.id.desc()).all()
return {"items": [_api_to_dict(a) for a in apis], "total": len(apis)}
@router.post("/")
def create_api(
data: ApiCreate,
db: Session = Depends(get_db),
user: User = Depends(verify_hub_access),
):
api = SharedApi(
category_id=data.category_id, name=data.name, description=data.description,
base_url=data.base_url, doc_url=data.doc_url,
auth_type=data.auth_type,
api_key_encrypted=encrypt_key(data.api_key) if data.api_key else "",
api_key_header=data.api_key_header,
health_check_url=data.health_check_url,
tags=data.tags, added_by=user.id,
)
db.add(api)
db.commit()
db.refresh(api)
return _api_to_dict(api)
@router.put("/{api_id}")
def update_api(
api_id: int,
data: ApiUpdate,
db: Session = Depends(get_db),
user: User = Depends(verify_hub_access),
):
api = db.query(SharedApi).filter(SharedApi.id == api_id).first()
if not api:
raise HTTPException(status_code=404, detail="API不存在")
if data.category_id is not None:
api.category_id = data.category_id
if data.name is not None:
api.name = data.name
if data.description is not None:
api.description = data.description
if data.base_url is not None:
api.base_url = data.base_url
if data.doc_url is not None:
api.doc_url = data.doc_url
if data.auth_type is not None:
api.auth_type = data.auth_type
if data.api_key is not None and data.api_key != "":
api.api_key_encrypted = encrypt_key(data.api_key)
if data.api_key_header is not None:
api.api_key_header = data.api_key_header
if data.health_check_url is not None:
api.health_check_url = data.health_check_url
if data.tags is not None:
api.tags = data.tags
if data.is_active is not None:
api.is_active = data.is_active
db.commit()
db.refresh(api)
return _api_to_dict(api)
@router.delete("/{api_id}")
def delete_api(
api_id: int,
db: Session = Depends(get_db),
user: User = Depends(verify_hub_access),
):
api = db.query(SharedApi).filter(SharedApi.id == api_id).first()
if not api:
raise HTTPException(status_code=404, detail="API不存在")
db.query(SharedApiLog).filter(SharedApiLog.api_id == api_id).delete()
db.delete(api)
db.commit()
return {"message": "删除成功"}
# ========== API测试 ==========
@router.post("/{api_id}/test")
async def test_api(
api_id: int,
data: ApiTestRequest,
db: Session = Depends(get_db),
user: User = Depends(verify_hub_access),
):
"""在线测试API后端代理请求"""
api = db.query(SharedApi).filter(SharedApi.id == api_id).first()
if not api:
raise HTTPException(status_code=404, detail="API不存在")
url = api.base_url.rstrip("/")
if data.path:
url = url + "/" + data.path.lstrip("/")
headers = dict(data.headers) if data.headers else {}
# 注入认证信息
if api.auth_type != "none" and api.api_key_encrypted:
key = decrypt_key(api.api_key_encrypted)
if key:
if api.auth_type == "bearer":
headers[api.api_key_header] = f"Bearer {key}"
elif api.auth_type == "api_key":
headers[api.api_key_header] = key
elif api.auth_type == "basic":
import base64
headers["Authorization"] = f"Basic {base64.b64encode(key.encode()).decode()}"
import httpx
start = time.time()
try:
async with httpx.AsyncClient(timeout=15) as client:
if data.method.upper() == "POST":
resp = await client.post(url, headers=headers, content=data.body or None)
elif data.method.upper() == "PUT":
resp = await client.put(url, headers=headers, content=data.body or None)
elif data.method.upper() == "DELETE":
resp = await client.delete(url, headers=headers)
else:
resp = await client.get(url, headers=headers)
elapsed = int((time.time() - start) * 1000)
# 记录日志
log = SharedApiLog(
api_id=api_id, user_id=user.id, action="test",
request_url=url, response_status=resp.status_code, response_time_ms=elapsed,
)
db.add(log)
api.call_count = (api.call_count or 0) + 1
db.commit()
# 限制返回体大小
body = resp.text[:5000] if len(resp.text) > 5000 else resp.text
return {
"status_code": resp.status_code,
"response_time_ms": elapsed,
"headers": dict(resp.headers),
"body": body,
}
except Exception as e:
elapsed = int((time.time() - start) * 1000)
log = SharedApiLog(
api_id=api_id, user_id=user.id, action="test",
request_url=url, response_status=0, response_time_ms=elapsed,
)
db.add(log)
db.commit()
return {"status_code": 0, "response_time_ms": elapsed, "headers": {}, "body": str(e)}
@router.post("/{api_id}/health-check")
async def health_check(
api_id: int,
db: Session = Depends(get_db),
user: User = Depends(verify_hub_access),
):
"""健康检查"""
api = db.query(SharedApi).filter(SharedApi.id == api_id).first()
if not api:
raise HTTPException(status_code=404, detail="API不存在")
check_url = api.health_check_url or api.base_url
import httpx
start = time.time()
try:
async with httpx.AsyncClient(timeout=10) as client:
resp = await client.get(check_url)
elapsed = int((time.time() - start) * 1000)
status = "ok" if resp.status_code < 400 else "error"
except Exception:
elapsed = int((time.time() - start) * 1000)
status = "error"
api.last_check_time = datetime.utcnow()
api.last_check_status = status
log = SharedApiLog(
api_id=api_id, user_id=user.id, action="health_check",
request_url=check_url, response_status=resp.status_code if status == "ok" else 0,
response_time_ms=elapsed,
)
db.add(log)
db.commit()
return {"status": status, "response_time_ms": elapsed}
# ========== 日志与统计 ==========
@router.get("/{api_id}/logs")
def get_api_logs(
api_id: int,
limit: int = 20,
db: Session = Depends(get_db),
user: User = Depends(verify_hub_access),
):
logs = (
db.query(SharedApiLog)
.filter(SharedApiLog.api_id == api_id)
.order_by(SharedApiLog.id.desc())
.limit(limit)
.all()
)
return [
{
"id": l.id, "action": l.action, "request_url": l.request_url,
"response_status": l.response_status, "response_time_ms": l.response_time_ms,
"user_id": l.user_id,
"created_at": l.created_at.isoformat() if l.created_at else None,
}
for l in logs
]
@router.get("/stats")
def get_stats(
db: Session = Depends(get_db),
user: User = Depends(verify_hub_access),
):
total_apis = db.query(sa_func.count(SharedApi.id)).filter(SharedApi.is_active == True).scalar() or 0
total_calls = db.query(sa_func.sum(SharedApi.call_count)).scalar() or 0
total_categories = db.query(sa_func.count(SharedApiCategory.id)).scalar() or 0
healthy = db.query(sa_func.count(SharedApi.id)).filter(SharedApi.last_check_status == "ok").scalar() or 0
return {
"total_apis": total_apis,
"total_calls": total_calls,
"total_categories": total_categories,
"healthy_count": healthy,
}

206
backend/routers/upload.py Normal file
View File

@@ -0,0 +1,206 @@
"""通用文件上传路由 - 腾讯云COS"""
import uuid
import os
from datetime import datetime
from fastapi import APIRouter, UploadFile, File, Depends, HTTPException
from sqlalchemy.orm import Session
from routers.auth import get_current_user
from models.user import User
from models.system_config import SystemConfig
from database import get_db
from config import MAX_UPLOAD_SIZE, MAX_ATTACHMENT_SIZE, UPLOAD_DIR
from models.attachment import Attachment
router = APIRouter()
ALLOWED_IMAGE_TYPES = ["image/jpeg", "image/png", "image/gif", "image/webp"]
ALLOWED_ATTACHMENT_TYPES = [
"application/pdf",
"application/msword",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"application/vnd.ms-excel",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
"application/vnd.ms-powerpoint",
"application/vnd.openxmlformats-officedocument.presentationml.presentation",
"application/zip",
"application/x-zip-compressed",
"application/x-rar-compressed",
"application/vnd.rar",
]
ALLOWED_ATTACHMENT_EXTS = {
"pdf", "doc", "docx", "xls", "xlsx", "ppt", "pptx", "zip", "rar",
}
def _get_cos_config(db: Session) -> dict:
"""从数据库读取COS配置"""
keys = ["cos_secret_id", "cos_secret_key", "cos_bucket", "cos_region", "cos_custom_domain"]
config = {}
for k in keys:
row = db.query(SystemConfig).filter(SystemConfig.key == k).first()
config[k] = row.value if row else ""
return config
def _get_cos_client(db: Session):
"""获取COS客户端实例未配置则返回None"""
config = _get_cos_config(db)
secret_id = config.get("cos_secret_id", "")
secret_key = config.get("cos_secret_key", "")
bucket = config.get("cos_bucket", "")
region = config.get("cos_region", "")
if not all([secret_id, secret_key, bucket, region]):
return None, config
try:
from qcloud_cos import CosConfig, CosS3Client
cos_config = CosConfig(Region=region, SecretId=secret_id, SecretKey=secret_key)
client = CosS3Client(cos_config)
return client, config
except ImportError:
return None, config
def _build_cos_url(config: dict, object_key: str) -> str:
"""构建COS访问URL"""
custom_domain = config.get("cos_custom_domain", "")
if custom_domain:
return f"https://{custom_domain}/{object_key}"
bucket = config.get("cos_bucket", "")
region = config.get("cos_region", "")
return f"https://{bucket}.cos.{region}.myqcloud.com/{object_key}"
@router.post("/image")
async def upload_image(
file: UploadFile = File(...),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""上传图片优先OSS未配置则本地存储"""
# 验证文件类型
if file.content_type not in ALLOWED_IMAGE_TYPES:
raise HTTPException(status_code=400, detail="不支持的文件类型,仅支持 JPG/PNG/GIF/WEBP")
# 读取文件内容
content = await file.read()
# 验证文件大小
if len(content) > MAX_UPLOAD_SIZE:
raise HTTPException(status_code=400, detail=f"文件过大,最大支持 {MAX_UPLOAD_SIZE // (1024*1024)}MB")
# 生成唯一文件名
ext = file.filename.split(".")[-1].lower() if "." in file.filename else "png"
date_prefix = datetime.now().strftime("%Y/%m")
filename = f"{uuid.uuid4().hex}.{ext}"
object_key = f"images/{date_prefix}/{filename}"
# 尝试COS上传从数据库读取配置
client, config = _get_cos_client(db)
if client:
try:
client.put_object(
Bucket=config.get("cos_bucket", ""),
Body=content,
Key=object_key,
ContentType=file.content_type,
)
url = _build_cos_url(config, object_key)
return {"url": url, "storage": "cos"}
except Exception as e:
raise HTTPException(status_code=500, detail=f"COS上传失败: {str(e)}")
else:
# 降级到本地存储
os.makedirs(UPLOAD_DIR, exist_ok=True)
filepath = os.path.join(UPLOAD_DIR, filename)
with open(filepath, "wb") as f:
f.write(content)
return {"url": f"/uploads/{filename}", "storage": "local"}
@router.post("/attachment")
async def upload_attachment(
file: UploadFile = File(...),
post_id: int = 0,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""上传附件到COS支持 PDF/Word/Excel/PPT/ZIP/RAR"""
# 验证文件扩展名
ext = file.filename.split(".")[-1].lower() if "." in file.filename else ""
if ext not in ALLOWED_ATTACHMENT_EXTS:
raise HTTPException(
status_code=400,
detail=f"不支持的文件类型,仅支持: {', '.join(sorted(ALLOWED_ATTACHMENT_EXTS))}",
)
# 读取文件内容
content = await file.read()
if len(content) > MAX_ATTACHMENT_SIZE:
raise HTTPException(
status_code=400,
detail=f"文件过大,最大支持 {MAX_ATTACHMENT_SIZE // (1024*1024)}MB",
)
# 生成唯一文件名
date_prefix = datetime.now().strftime("%Y/%m")
unique_name = f"{uuid.uuid4().hex}.{ext}"
object_key = f"attachments/{date_prefix}/{unique_name}"
# 上传到COS
client, config = _get_cos_client(db)
if not client:
raise HTTPException(status_code=500, detail="对象存储未配置,无法上传附件")
try:
client.put_object(
Bucket=config.get("cos_bucket", ""),
Body=content,
Key=object_key,
ContentType=file.content_type or "application/octet-stream",
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"COS上传失败: {str(e)}")
url = _build_cos_url(config, object_key)
# 写入数据库
attachment = Attachment(
post_id=post_id if post_id else None,
user_id=current_user.id,
filename=file.filename,
storage_key=object_key,
url=url,
file_size=len(content),
file_type=file.content_type or "application/octet-stream",
)
db.add(attachment)
db.commit()
db.refresh(attachment)
return {
"id": attachment.id,
"filename": attachment.filename,
"url": attachment.url,
"file_size": attachment.file_size,
"file_type": attachment.file_type,
}
@router.put("/attachment/{attachment_id}/post")
async def update_attachment_post(
attachment_id: int,
post_id: int = 0,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""将附件关联到帖子(新建帖子后回填 post_id"""
attachment = db.query(Attachment).filter(
Attachment.id == attachment_id,
Attachment.user_id == current_user.id,
).first()
if not attachment:
raise HTTPException(status_code=404, detail="附件不存在")
attachment.post_id = post_id
db.commit()
return {"message": "ok"}

214
backend/routers/users.py Normal file
View File

@@ -0,0 +1,214 @@
"""用户主页和关注系统路由"""
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from sqlalchemy import func
from database import get_db
from models.user import User
from models.post import Post
from models.follow import Follow
from models.like import Collect
from models.notification import Notification
from routers.auth import get_current_user
router = APIRouter()
@router.get("/{user_id}")
def get_user_profile(
user_id: int,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""获取用户主页信息"""
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise HTTPException(status_code=404, detail="用户不存在")
post_count = db.query(func.count(Post.id)).filter(Post.user_id == user_id, Post.is_public == True).scalar()
follower_count = db.query(func.count(Follow.id)).filter(Follow.following_id == user_id).scalar()
following_count = db.query(func.count(Follow.id)).filter(Follow.follower_id == user_id).scalar()
is_following = db.query(Follow).filter(
Follow.follower_id == current_user.id, Follow.following_id == user_id
).first() is not None
return {
"id": user.id,
"username": user.username,
"email": user.email,
"avatar": user.avatar,
"created_at": user.created_at,
"post_count": post_count,
"follower_count": follower_count,
"following_count": following_count,
"is_following": is_following,
"is_self": current_user.id == user_id,
}
@router.post("/{user_id}/follow")
def toggle_follow(
user_id: int,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""关注/取消关注"""
if current_user.id == user_id:
raise HTTPException(status_code=400, detail="不能关注自己")
target = db.query(User).filter(User.id == user_id).first()
if not target:
raise HTTPException(status_code=404, detail="用户不存在")
existing = db.query(Follow).filter(
Follow.follower_id == current_user.id, Follow.following_id == user_id
).first()
if existing:
db.delete(existing)
db.commit()
return {"followed": False}
else:
follow = Follow(follower_id=current_user.id, following_id=user_id)
db.add(follow)
# 创建通知
notif = Notification(
user_id=user_id,
type="follow",
content=f"{current_user.username} 关注了你",
from_user_id=current_user.id,
related_id=current_user.id,
)
db.add(notif)
db.commit()
return {"followed": True}
@router.get("/{user_id}/posts")
def get_user_posts(
user_id: int,
page: int = 1,
page_size: int = 20,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""获取用户的帖子列表"""
query = db.query(Post).filter(Post.user_id == user_id)
if user_id != current_user.id:
query = query.filter(Post.is_public == True)
posts = query.order_by(Post.created_at.desc()).offset((page - 1) * page_size).limit(page_size).all()
user = db.query(User).filter(User.id == user_id).first()
username = user.username if user else "未知"
avatar = user.avatar if user else ""
return [
{
"id": p.id, "title": p.title, "content": p.content[:200],
"category": p.category, "tags": p.tags,
"view_count": p.view_count, "like_count": p.like_count,
"comment_count": p.comment_count, "collect_count": p.collect_count,
"created_at": p.created_at, "updated_at": p.updated_at,
"author": {"id": user_id, "username": username, "avatar": avatar},
}
for p in posts
]
@router.get("/{user_id}/followers")
def get_followers(
user_id: int,
page: int = 1,
page_size: int = 20,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""获取粉丝列表"""
follows = (
db.query(Follow)
.filter(Follow.following_id == user_id)
.order_by(Follow.created_at.desc())
.offset((page - 1) * page_size).limit(page_size).all()
)
user_ids = [f.follower_id for f in follows]
users = db.query(User).filter(User.id.in_(user_ids)).all() if user_ids else []
user_map = {u.id: u for u in users}
return [
{
"id": uid,
"username": user_map[uid].username if uid in user_map else "",
"avatar": user_map[uid].avatar if uid in user_map else "",
}
for uid in user_ids
]
@router.get("/{user_id}/following")
def get_following(
user_id: int,
page: int = 1,
page_size: int = 20,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""获取关注列表"""
follows = (
db.query(Follow)
.filter(Follow.follower_id == user_id)
.order_by(Follow.created_at.desc())
.offset((page - 1) * page_size).limit(page_size).all()
)
user_ids = [f.following_id for f in follows]
users = db.query(User).filter(User.id.in_(user_ids)).all() if user_ids else []
user_map = {u.id: u for u in users}
return [
{
"id": uid,
"username": user_map[uid].username if uid in user_map else "",
"avatar": user_map[uid].avatar if uid in user_map else "",
}
for uid in user_ids
]
@router.get("/{user_id}/collects")
def get_user_collects(
user_id: int,
page: int = 1,
page_size: int = 20,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""获取用户收藏的帖子"""
collects = (
db.query(Collect)
.filter(Collect.user_id == user_id)
.order_by(Collect.created_at.desc())
.offset((page - 1) * page_size).limit(page_size).all()
)
post_ids = [c.post_id for c in collects]
if not post_ids:
return []
posts = db.query(Post).filter(Post.id.in_(post_ids)).all()
post_map = {p.id: p for p in posts}
author_ids = list(set(p.user_id for p in posts))
authors = db.query(User).filter(User.id.in_(author_ids)).all()
author_map = {a.id: a for a in authors}
result = []
for pid in post_ids:
p = post_map.get(pid)
if not p:
continue
a = author_map.get(p.user_id)
result.append({
"id": p.id, "title": p.title, "content": p.content[:200],
"category": p.category, "tags": p.tags,
"view_count": p.view_count, "like_count": p.like_count,
"comment_count": p.comment_count, "collect_count": p.collect_count,
"created_at": p.created_at,
"author": {"id": p.user_id, "username": a.username if a else "", "avatar": a.avatar if a else ""},
})
return result

View File

@@ -0,0 +1,274 @@
"""联网搜索路由 - 使用豆包大模型 + 火山方舟 web_search"""
import json
import httpx
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from pydantic import BaseModel
from database import get_db
from models.user import User
from models.conversation import Conversation, Message
from schemas.conversation import ConversationResponse, ConversationDetail, MessageResponse
from routers.auth import get_current_user
from config import ARK_API_KEY, ARK_ENDPOINT, ARK_BASE_URL
router = APIRouter()
WEB_SEARCH_SYSTEM_PROMPT = """你是一个智能联网搜索助手。你可以通过联网搜索获取最新信息来回答用户的问题。
回答要求:
1. 基于搜索到的最新信息给出准确、详细的回答
2. 使用清晰的 Markdown 格式组织内容
3. 如果涉及时效性信息,注明信息的时间
4. 对搜索结果进行整合和总结,而非简单罗列
5. 如果搜索结果不足以回答问题,诚实告知并给出建议"""
class WebSearchRequest(BaseModel):
conversation_id: Optional[int] = None
content: str
model_config_id: Optional[int] = None
@router.post("/search")
async def web_search(
request: WebSearchRequest,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""联网搜索 - 流式输出"""
# 获取模型配置:优先用指定的,否则用 .env 的 ARK 配置
from services.ai_service import _get_db_model_config
ark_api_key = ARK_API_KEY
ark_endpoint = ARK_ENDPOINT
ark_base_url = ARK_BASE_URL
search_count = 5
if request.model_config_id:
cfg = _get_db_model_config("", request.model_config_id)
if cfg:
ark_api_key = cfg["api_key"]
ark_endpoint = cfg["model"]
ark_base_url = cfg["base_url"] or ARK_BASE_URL
search_count = cfg.get("web_search_count", 5)
if not ark_api_key:
raise HTTPException(status_code=500, detail="未配置火山方舟 API Key")
# 创建或获取对话
if request.conversation_id:
conv = db.query(Conversation).filter(
Conversation.id == request.conversation_id,
Conversation.user_id == current_user.id,
).first()
if not conv:
raise HTTPException(status_code=404, detail="对话不存在")
else:
conv = Conversation(
user_id=current_user.id,
title=request.content[:50] if request.content else "新搜索",
type="web_search",
)
db.add(conv)
db.commit()
db.refresh(conv)
# 保存用户消息
user_msg = Message(
conversation_id=conv.id,
role="user",
content=request.content,
)
db.add(user_msg)
db.commit()
# 构建历史消息
history_msgs = (
db.query(Message)
.filter(Message.conversation_id == conv.id)
.order_by(Message.created_at.asc())
.all()
)
messages = [{"role": "system", "content": WEB_SEARCH_SYSTEM_PROMPT}]
for msg in history_msgs:
messages.append({"role": msg.role, "content": msg.content})
# 流式调用火山方舟 API
url = f"{ark_base_url}/chat/completions"
headers = {
"Authorization": f"Bearer {ark_api_key}",
"Content-Type": "application/json",
}
async def generate():
full_response = ""
try:
payload_with_search = {
"model": ark_endpoint,
"messages": messages,
"stream": True,
"tools": [
{
"type": "web_search",
"web_search": {"enable": True, "search_result_count": max(1, min(50, search_count))},
}
],
}
payload_without_search = {
"model": ark_endpoint,
"messages": messages,
"stream": True,
}
# 先尝试带联网搜索调用
use_fallback = False
async with httpx.AsyncClient(timeout=120.0) as client:
async with client.stream("POST", url, json=payload_with_search, headers=headers) as resp:
if resp.status_code == 400:
# 模型不支持 web_search tools降级到普通调用
await resp.aread()
use_fallback = True
elif resp.status_code != 200:
error_body = await resp.aread()
error_msg = f"API 调用失败 ({resp.status_code}): {error_body.decode()}"
yield f"data: {json.dumps({'content': error_msg, 'done': False})}\n\n"
full_response = error_msg
else:
buffer = ""
async for chunk in resp.aiter_text():
buffer += chunk
while "\n" in buffer:
line, buffer = buffer.split("\n", 1)
line = line.strip()
if not line or not line.startswith("data: "):
continue
data_str = line[6:]
if data_str == "[DONE]":
continue
try:
data = json.loads(data_str)
choices = data.get("choices", [])
if choices:
delta = choices[0].get("delta", {})
content = delta.get("content", "")
if content:
full_response += content
yield f"data: {json.dumps({'content': content, 'done': False})}\n\n"
except json.JSONDecodeError:
pass
# 降级:不带 web_search tools 重试
if use_fallback:
async with httpx.AsyncClient(timeout=120.0) as client:
async with client.stream("POST", url, json=payload_without_search, headers=headers) as resp:
if resp.status_code != 200:
error_body = await resp.aread()
error_msg = f"API 调用失败 ({resp.status_code}): {error_body.decode()}"
yield f"data: {json.dumps({'content': error_msg, 'done': False})}\n\n"
full_response = error_msg
else:
buffer = ""
async for chunk in resp.aiter_text():
buffer += chunk
while "\n" in buffer:
line, buffer = buffer.split("\n", 1)
line = line.strip()
if not line or not line.startswith("data: "):
continue
data_str = line[6:]
if data_str == "[DONE]":
continue
try:
data = json.loads(data_str)
choices = data.get("choices", [])
if choices:
delta = choices[0].get("delta", {})
content = delta.get("content", "")
if content:
full_response += content
yield f"data: {json.dumps({'content': content, 'done': False})}\n\n"
except json.JSONDecodeError:
pass
except Exception as e:
error_msg = f"联网搜索出错: {str(e)}"
if not full_response:
full_response = error_msg
yield f"data: {json.dumps({'content': error_msg, 'done': False})}\n\n"
# 保存AI回复
if full_response:
ai_msg = Message(
conversation_id=conv.id,
role="assistant",
content=full_response,
)
db.add(ai_msg)
db.commit()
yield f"data: {json.dumps({'content': '', 'done': True, 'conversation_id': conv.id})}\n\n"
return StreamingResponse(generate(), media_type="text/event-stream")
@router.get("/conversations", response_model=List[ConversationResponse])
def get_conversations(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""获取联网搜索对话列表"""
conversations = (
db.query(Conversation)
.filter(Conversation.user_id == current_user.id, Conversation.type == "web_search")
.order_by(Conversation.updated_at.desc())
.all()
)
return [ConversationResponse.model_validate(c) for c in conversations]
@router.get("/conversations/{conversation_id}", response_model=ConversationDetail)
def get_conversation_detail(
conversation_id: int,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""获取对话详情"""
conv = db.query(Conversation).filter(
Conversation.id == conversation_id,
Conversation.user_id == current_user.id,
).first()
if not conv:
raise HTTPException(status_code=404, detail="对话不存在")
msgs = (
db.query(Message)
.filter(Message.conversation_id == conversation_id)
.order_by(Message.created_at.asc())
.all()
)
result = ConversationDetail.model_validate(conv)
result.messages = [MessageResponse.model_validate(m) for m in msgs]
return result
@router.delete("/conversations/{conversation_id}")
def delete_conversation(
conversation_id: int,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""删除对话"""
conv = db.query(Conversation).filter(
Conversation.id == conversation_id,
Conversation.user_id == current_user.id,
).first()
if not conv:
raise HTTPException(status_code=404, detail="对话不存在")
db.query(Message).filter(Message.conversation_id == conversation_id).delete()
db.delete(conv)
db.commit()
return {"message": "删除成功"}

View File

View File

@@ -0,0 +1,62 @@
"""AI模型配置Schema"""
from pydantic import BaseModel
from datetime import datetime
from typing import Optional, List
class AIModelCreate(BaseModel):
provider: str
provider_name: str = ""
model_id: str
model_name: str = ""
api_key: str = ""
base_url: str = ""
task_type: str = ""
is_enabled: bool = True
is_default: bool = False
web_search_enabled: bool = False
web_search_count: int = 5 # 联网搜索结果条数1-50
description: str = ""
class AIModelUpdate(BaseModel):
provider_name: Optional[str] = None
model_id: Optional[str] = None
model_name: Optional[str] = None
api_key: Optional[str] = None
base_url: Optional[str] = None
task_type: Optional[str] = None
is_enabled: Optional[bool] = None
is_default: Optional[bool] = None
web_search_enabled: Optional[bool] = None
web_search_count: Optional[int] = None
description: Optional[str] = None
class AIModelResponse(BaseModel):
id: int
provider: str
provider_name: str = ""
model_id: str
model_name: str = ""
api_key_masked: str = "" # 脱敏后的API Key
base_url: str = ""
task_type: str = ""
is_enabled: bool = True
is_default: bool = False
web_search_enabled: bool = False
web_search_count: int = 5
description: str = ""
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
class ProviderInfo(BaseModel):
"""服务商信息"""
provider: str
name: str
models: List[dict]
default_base_url: str = ""

View File

@@ -0,0 +1,38 @@
"""网站收藏Schema"""
from pydantic import BaseModel
from datetime import datetime
from typing import Optional, List
class BookmarkCreate(BaseModel):
name: str
url: str
icon: str = ""
class BookmarkUpdate(BaseModel):
name: Optional[str] = None
url: Optional[str] = None
icon: Optional[str] = None
sort_order: Optional[int] = None
class BookmarkResponse(BaseModel):
id: int
name: str
url: str
icon: str = ""
sort_order: int = 0
created_at: datetime
class Config:
from_attributes = True
class ReorderItem(BaseModel):
id: int
sort_order: int
class ReorderRequest(BaseModel):
items: List[ReorderItem]

View File

@@ -0,0 +1,55 @@
"""对话相关Schema"""
from pydantic import BaseModel
from datetime import datetime
from typing import Optional, List
class MessageCreate(BaseModel):
content: str
image_urls: List[str] = []
class MessageResponse(BaseModel):
id: int
conversation_id: int
role: str
content: str
image_urls: str = ""
created_at: datetime
class Config:
from_attributes = True
class ConversationCreate(BaseModel):
type: str # requirement / architecture
title: str = "新对话"
class ConversationResponse(BaseModel):
id: int
user_id: int
title: str
type: str
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
class ConversationDetail(ConversationResponse):
messages: List[MessageResponse] = []
class RequirementAnalyzeRequest(BaseModel):
conversation_id: Optional[int] = None
content: str
image_urls: List[str] = []
model_config_id: Optional[int] = None
class ArchitectureRequest(BaseModel):
conversation_id: Optional[int] = None
content: str
model_config_id: Optional[int] = None

69
backend/schemas/post.py Normal file
View File

@@ -0,0 +1,69 @@
"""帖子相关Schema"""
from pydantic import BaseModel
from datetime import datetime
from typing import Optional, List
class PostCreate(BaseModel):
title: str
content: str
category: str = ""
tags: List[str] = []
is_public: bool = True
is_draft: bool = False
class PostUpdate(BaseModel):
title: Optional[str] = None
content: Optional[str] = None
category: Optional[str] = None
tags: Optional[List[str]] = None
is_public: Optional[bool] = None
is_draft: Optional[bool] = None
class PostResponse(BaseModel):
id: int
user_id: int
title: str
content: str
category: str = ""
tags: str = ""
is_public: bool = True
is_draft: bool = False
view_count: int = 0
like_count: int = 0
collect_count: int = 0
comment_count: int = 0
created_at: datetime
updated_at: datetime
# 额外字段(查询时填充)
author_name: str = ""
is_liked: bool = False
is_collected: bool = False
class Config:
from_attributes = True
class PostListResponse(BaseModel):
items: List[PostResponse]
total: int
page: int
page_size: int
class CommentCreate(BaseModel):
content: str
class CommentResponse(BaseModel):
id: int
post_id: int
user_id: int
content: str
created_at: datetime
author_name: str = ""
class Config:
from_attributes = True

43
backend/schemas/user.py Normal file
View File

@@ -0,0 +1,43 @@
"""用户相关Schema"""
from pydantic import BaseModel, EmailStr
from datetime import datetime
from typing import Optional
class UserRegister(BaseModel):
username: str
email: str
password: str
class UserLogin(BaseModel):
username: str
password: str
class UserResponse(BaseModel):
id: int
username: str
email: str
avatar: str = ""
is_admin: bool = False
is_banned: bool = False
is_approved: bool = False
created_at: datetime
class Config:
from_attributes = True
class UserUpdate(BaseModel):
username: Optional[str] = None
email: Optional[str] = None
avatar: Optional[str] = None
old_password: Optional[str] = None
new_password: Optional[str] = None
class TokenResponse(BaseModel):
access_token: str
token_type: str = "bearer"
user: UserResponse

View File

View File

@@ -0,0 +1,429 @@
"""统一大模型调用服务 - 支持多模型路由和流式输出"""
import json
import httpx
import asyncio
from typing import AsyncGenerator, List, Optional, Union
from openai import AsyncOpenAI
from anthropic import AsyncAnthropic
from config import (
MODEL_CONFIG,
OPENAI_API_KEY,
ANTHROPIC_API_KEY,
GOOGLE_API_KEY,
DEEPSEEK_API_KEY,
ARK_API_KEY,
ARK_BASE_URL,
)
def _get_db_model_config(task_type: str, model_config_id: int = None):
"""从数据库获取指定任务类型的默认模型配置,或指定 ID 的模型"""
try:
from database import SessionLocal
from models.ai_model import AIModelConfig
db = SessionLocal()
try:
# 如果指定了模型 ID直接用该模型
if model_config_id:
model = db.query(AIModelConfig).filter(
AIModelConfig.id == model_config_id,
AIModelConfig.is_enabled == True,
).first()
if model and model.api_key:
return {
"provider": model.provider,
"model": model.model_id,
"api_key": model.api_key,
"base_url": model.base_url,
"web_search_enabled": model.web_search_enabled,
"web_search_count": model.web_search_count or 5,
}
# 否则找默认模型
model = db.query(AIModelConfig).filter(
AIModelConfig.task_type == task_type,
AIModelConfig.is_default == True,
AIModelConfig.is_enabled == True,
).first()
if model and model.api_key:
return {
"provider": model.provider,
"model": model.model_id,
"api_key": model.api_key,
"base_url": model.base_url,
"web_search_enabled": model.web_search_enabled,
"web_search_count": model.web_search_count or 5,
}
# 没有默认的找任意一个启用且有Key的
model = db.query(AIModelConfig).filter(
AIModelConfig.task_type == task_type,
AIModelConfig.is_enabled == True,
AIModelConfig.api_key != "",
).first()
if model:
return {
"provider": model.provider,
"model": model.model_id,
"api_key": model.api_key,
"base_url": model.base_url,
"web_search_enabled": model.web_search_enabled,
"web_search_count": model.web_search_count or 5,
}
finally:
db.close()
except Exception:
pass
return None
class AIService:
"""统一AI服务根据任务类型路由到不同大模型"""
def __init__(self):
# OpenAI客户端也用于DeepSeek等兼容API
if OPENAI_API_KEY:
self.openai_client = AsyncOpenAI(api_key=OPENAI_API_KEY)
else:
self.openai_client = None
# Anthropic客户端
if ANTHROPIC_API_KEY:
self.anthropic_client = AsyncAnthropic(api_key=ANTHROPIC_API_KEY)
else:
self.anthropic_client = None
# DeepSeek客户端兼容OpenAI接口
if DEEPSEEK_API_KEY:
self.deepseek_client = AsyncOpenAI(
api_key=DEEPSEEK_API_KEY,
base_url="https://api.deepseek.com/v1",
)
else:
self.deepseek_client = None
def _get_client_for_provider(self, provider: str, api_key: str, base_url: str = ""):
"""根据provider动态创建客户端"""
if provider == "anthropic":
return AsyncAnthropic(api_key=api_key)
# openai/deepseek/google 都用OpenAI兼容接口
kwargs = {"api_key": api_key}
if base_url:
kwargs["base_url"] = base_url
return AsyncOpenAI(**kwargs)
async def chat(
self,
task_type: str,
messages: List[dict],
system_prompt: str = "",
stream: bool = False,
model_config_id: int = None,
) -> Union[str, AsyncGenerator[str, None]]:
"""
统一对话接口
参数:
task_type: 任务类型 (multimodal/reasoning/lightweight)
messages: 消息列表 [{"role": "user", "content": "..."}]
system_prompt: 系统提示词
stream: 是否流式输出
model_config_id: 指定模型配置ID可选不传则用默认
"""
# 优先从数据库读取模型配置
db_config = _get_db_model_config(task_type, model_config_id)
if db_config:
provider = db_config["provider"]
model = db_config["model"]
api_key = db_config["api_key"]
base_url = db_config["base_url"]
web_search = db_config.get("web_search_enabled", False)
web_search_count = db_config.get("web_search_count", 5)
# 火山方舟/豆包 + 联网搜索(开启后自动使用,失败则降级到普通调用)
if provider == "ark" and web_search:
try:
return await self._chat_ark_web_search(api_key, base_url, model, messages, system_prompt, stream, web_search_count)
except Exception:
# 联网搜索调用失败,降级到普通调用
pass
# 火山方舟/豆包 不带联网搜索OpenAI 兼容接口)
if provider == "ark":
kwargs = {"api_key": api_key}
if base_url:
kwargs["base_url"] = base_url
else:
kwargs["base_url"] = ARK_BASE_URL
client = AsyncOpenAI(**kwargs)
return await self._chat_openai(client, model, messages, system_prompt, stream)
if provider == "anthropic":
client = AsyncAnthropic(api_key=api_key)
return await self._chat_anthropic_with_client(client, model, messages, system_prompt, stream)
else:
kwargs = {"api_key": api_key}
if base_url:
kwargs["base_url"] = base_url
client = AsyncOpenAI(**kwargs)
# deepseek-reasoner 需要特殊处理
if model == "deepseek-reasoner":
return await self._chat_deepseek_reasoner(client, model, messages, system_prompt, stream)
return await self._chat_openai(client, model, messages, system_prompt, stream)
# 回退到 .env 配置
config = MODEL_CONFIG.get(task_type, MODEL_CONFIG["reasoning"])
provider = config["provider"]
model = config["model"]
if provider == "anthropic" and self.anthropic_client:
return await self._chat_anthropic(model, messages, system_prompt, stream)
elif provider == "openai" and self.openai_client:
return await self._chat_openai(self.openai_client, model, messages, system_prompt, stream)
elif provider == "deepseek" and self.deepseek_client:
return await self._chat_openai(self.deepseek_client, model, messages, system_prompt, stream)
elif provider == "google":
if GOOGLE_API_KEY:
google_client = AsyncOpenAI(
api_key=GOOGLE_API_KEY,
base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
)
return await self._chat_openai(google_client, model, messages, system_prompt, stream)
# 降级
if self.deepseek_client:
return await self._chat_openai(self.deepseek_client, "deepseek-chat", messages, system_prompt, stream)
if self.openai_client:
return await self._chat_openai(self.openai_client, "gpt-4o-mini", messages, system_prompt, stream)
if self.anthropic_client:
return await self._chat_anthropic("claude-sonnet-4-20250514", messages, system_prompt, stream)
return "未配置任何AI模型请到「模型管理」页面配置模型和API Key。"
async def _chat_ark_web_search(
self, api_key: str, base_url: str, model: str,
messages: List[dict], system_prompt: str, stream: bool,
search_count: int = 5,
) -> Union[str, AsyncGenerator[str, None]]:
"""火山方舟 + 联网搜索(使用 httpx 直接调用,因 web_search 是非标准 tools 类型)"""
url = f"{base_url or ARK_BASE_URL}/chat/completions"
full_messages = []
if system_prompt:
full_messages.append({"role": "system", "content": system_prompt})
full_messages.extend(messages)
# 限制搜索条数范围 1-50
search_count = max(1, min(50, search_count or 5))
payload = {
"model": model,
"messages": full_messages,
"stream": stream,
"tools": [{"type": "web_search", "web_search": {"enable": True, "search_result_count": search_count}}],
}
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
if not stream:
async with httpx.AsyncClient(timeout=120.0) as client:
resp = await client.post(url, json=payload, headers=headers)
if resp.status_code != 200:
# 联网搜索调用失败,抛出异常以便降级到普通调用
raise Exception(f"API调用失败 ({resp.status_code}): {resp.text[:200]}")
data = resp.json()
return data.get("choices", [{}])[0].get("message", {}).get("content", "")
else:
# 流式调用时,先发一个预检测请求确认模型支持 web_search
async with httpx.AsyncClient(timeout=10.0) as client:
test_payload = {
"model": model,
"messages": [{"role": "user", "content": "test"}],
"stream": False,
"max_tokens": 1,
"tools": [{"type": "web_search", "web_search": {"enable": True, "search_result_count": 1}}],
}
try:
resp = await client.post(url, json=test_payload, headers=headers)
if resp.status_code == 400:
# 模型不支持 web_search抛出异常以便降级
raise Exception("模型不支持联网搜索")
except httpx.TimeoutException:
pass # 超时不影响,继续尝试流式调用
return self._stream_ark_web_search(url, payload, headers)
async def _stream_ark_web_search(
self, url: str, payload: dict, headers: dict,
) -> AsyncGenerator[str, None]:
"""火山方舟联网搜索流式输出"""
async with httpx.AsyncClient(timeout=120.0) as client:
async with client.stream("POST", url, json=payload, headers=headers) as resp:
if resp.status_code != 200:
error_body = await resp.aread()
yield f"API调用失败 ({resp.status_code}): {error_body.decode()[:200]}"
return
buffer = ""
async for chunk in resp.aiter_text():
buffer += chunk
while "\n" in buffer:
line, buffer = buffer.split("\n", 1)
line = line.strip()
if not line or not line.startswith("data: "):
continue
data_str = line[6:]
if data_str == "[DONE]":
continue
try:
data = json.loads(data_str)
choices = data.get("choices", [])
if choices:
delta = choices[0].get("delta", {})
content = delta.get("content", "")
if content:
yield content
except json.JSONDecodeError:
pass
async def _chat_openai(
self, client: AsyncOpenAI, model: str, messages: List[dict],
system_prompt: str, stream: bool,
) -> Union[str, AsyncGenerator[str, None]]:
"""OpenAI兼容接口调用"""
full_messages = []
if system_prompt:
full_messages.append({"role": "system", "content": system_prompt})
full_messages.extend(messages)
if stream:
return self._stream_openai(client, model, full_messages)
else:
response = await client.chat.completions.create(
model=model,
messages=full_messages,
temperature=0.7,
max_tokens=4096,
)
return response.choices[0].message.content
async def _stream_openai(
self, client: AsyncOpenAI, model: str, messages: List[dict],
) -> AsyncGenerator[str, None]:
"""OpenAI流式输出"""
response = await client.chat.completions.create(
model=model,
messages=messages,
temperature=0.7,
max_tokens=4096,
stream=True,
)
async for chunk in response:
if chunk.choices[0].delta.content:
yield chunk.choices[0].delta.content
async def _chat_deepseek_reasoner(
self, client: AsyncOpenAI, model: str, messages: List[dict],
system_prompt: str, stream: bool,
) -> Union[str, AsyncGenerator[str, None]]:
"""DeepSeek Reasoner (思考模式) 专用调用
注意deepseek-reasoner 不支持 temperature/top_p/system 等参数
输出包含 reasoning_content思考过程和 content最终回答
"""
# reasoner 不支持 system role将 system prompt 合并到第一条用户消息
full_messages = []
for msg in messages:
full_messages.append(msg)
if system_prompt and full_messages:
first_user = None
for m in full_messages:
if m["role"] == "user":
first_user = m
break
if first_user:
first_user["content"] = f"[指令] {system_prompt}\n\n[用户输入] {first_user['content']}"
if stream:
return self._stream_deepseek_reasoner(client, model, full_messages)
else:
response = await client.chat.completions.create(
model=model,
messages=full_messages,
max_tokens=8192,
)
reasoning = getattr(response.choices[0].message, 'reasoning_content', '') or ''
content = response.choices[0].message.content or ''
if reasoning:
return f"<think>\n{reasoning}\n</think>\n\n{content}"
return content
async def _stream_deepseek_reasoner(
self, client: AsyncOpenAI, model: str, messages: List[dict],
) -> AsyncGenerator[str, None]:
"""DeepSeek Reasoner 流式输出 - 包含思考过程和最终回答"""
response = await client.chat.completions.create(
model=model,
messages=messages,
max_tokens=8192,
stream=True,
)
in_reasoning = False
reasoning_started = False
async for chunk in response:
delta = chunk.choices[0].delta
# 思考过程
reasoning_content = getattr(delta, 'reasoning_content', None)
if reasoning_content:
if not reasoning_started:
reasoning_started = True
in_reasoning = True
yield "<details>\n<summary>💭 思考过程</summary>\n\n"
yield reasoning_content
# 最终回答
if delta.content:
if in_reasoning:
in_reasoning = False
yield "\n</details>\n\n"
yield delta.content
async def _chat_anthropic(
self, model: str, messages: List[dict],
system_prompt: str, stream: bool,
) -> Union[str, AsyncGenerator[str, None]]:
"""Anthropic接口调用使用self.anthropic_client"""
return await self._chat_anthropic_with_client(self.anthropic_client, model, messages, system_prompt, stream)
async def _chat_anthropic_with_client(
self, client, model: str, messages: List[dict],
system_prompt: str, stream: bool,
) -> Union[str, AsyncGenerator[str, None]]:
"""Anthropic接口调用"""
if stream:
return self._stream_anthropic_with_client(client, model, messages, system_prompt)
else:
response = await client.messages.create(
model=model,
max_tokens=4096,
system=system_prompt if system_prompt else "You are a helpful assistant.",
messages=messages,
)
return response.content[0].text
async def _stream_anthropic(
self, model: str, messages: List[dict], system_prompt: str,
) -> AsyncGenerator[str, None]:
"""Anthropic流式输出使用self.anthropic_client"""
async for text in self._stream_anthropic_with_client(self.anthropic_client, model, messages, system_prompt):
yield text
async def _stream_anthropic_with_client(
self, client, model: str, messages: List[dict], system_prompt: str,
) -> AsyncGenerator[str, None]:
"""Anthropic流式输出"""
async with client.messages.stream(
model=model,
max_tokens=4096,
system=system_prompt if system_prompt else "You are a helpful assistant.",
messages=messages,
) as stream:
async for text in stream.text_stream:
yield text
# 单例
ai_service = AIService()

0
backend/uploads/.gitkeep Normal file
View File