175 lines
6.8 KiB
Python
175 lines
6.8 KiB
Python
"""FastAPI 应用入口"""
|
||
from fastapi import FastAPI
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from fastapi.staticfiles import StaticFiles
|
||
import os
|
||
|
||
from database import init_db
|
||
from config import UPLOAD_DIR
|
||
import models.system_config # 确保 system_configs 表被创建
|
||
import models.category # 确保 categories 表被创建
|
||
import models.nav_category # 确保 nav_categories 表被创建
|
||
import models.nav_link # 确保 nav_links 表被创建
|
||
import models.project # 确保 projects 表被创建
|
||
import models.shared_api # 确保 shared_api 相关表被创建
|
||
import models.knowledge_base # 确保 kb 相关表被创建
|
||
import models.attachment # 确保 attachments 表被创建
|
||
from routers import auth, requirement, architecture, posts, search, ai_models, bookmarks, users, notifications, upload, admin, nav, projects, shared_api, knowledge_base, web_search, ai_format
|
||
|
||
# Make sure the upload directory exists: StaticFiles (mounted below) checks
# that its directory exists when it is constructed.
os.makedirs(UPLOAD_DIR, exist_ok=True)

app = FastAPI(
    title="极码 GeekCode",
    description="极码 GeekCode - 开发者社区、AI工具库、经验知识库",
    version="2.0.0",
)

# Browser origins allowed to make (credentialed) cross-origin requests.
# Defaults to the local frontend dev server on port 5173. Set the
# CORS_ALLOW_ORIGINS environment variable to a comma-separated list of
# origins (e.g. a production domain) to override it; an unset or empty
# variable falls back to the default.
_DEFAULT_CORS_ORIGINS = "http://localhost:5173,http://127.0.0.1:5173"
CORS_ALLOW_ORIGINS = [
    origin.strip()
    for origin in (os.environ.get("CORS_ALLOW_ORIGINS") or _DEFAULT_CORS_ORIGINS).split(",")
    if origin.strip()
]

# CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ALLOW_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Static files (uploaded images, etc.) served from UPLOAD_DIR under /uploads.
app.mount("/uploads", StaticFiles(directory=UPLOAD_DIR), name="uploads")
|
||
|
||
# Register API routers. Entries are included in exactly this order, which is
# also the order in which their routes are matched.
# Each entry is (router, URL prefix, OpenAPI tag); a prefix of None means the
# router is included as-is, without a prefix or tags.
_ROUTER_TABLE = (
    (auth.router, "/api/auth", "认证"),
    (requirement.router, "/api/requirement", "需求理解助手"),
    (architecture.router, "/api/architecture", "架构选型助手"),
    (posts.router, "/api/posts", "经验知识库"),
    (search.router, "/api/search", "搜索"),
    (ai_models.router, None, None),
    (ai_models.public_router, None, None),
    (bookmarks.router, "/api/bookmarks", "网站收藏"),
    (users.router, "/api/users", "用户"),
    (notifications.router, "/api/notifications", "消息通知"),
    (upload.router, "/api/upload", "文件上传"),
    (admin.router, "/api/admin", "后台管理"),
    (nav.router, "/api/nav", "导航站"),
    (projects.router, "/api/projects", "开源项目"),
    (shared_api.router, "/api/api-hub", "API Hub"),
    (knowledge_base.router, "/api/kb", "团队知识库"),
    (web_search.router, "/api/web-search", "联网搜索"),
    (ai_format.router, None, None),
)

for _router, _prefix, _tag in _ROUTER_TABLE:
    if _prefix is None:
        app.include_router(_router)
    else:
        app.include_router(_router, prefix=_prefix, tags=[_tag])
|
||
|
||
|
||
@app.on_event("startup")
async def startup():
    """Run one-time database setup when the application starts.

    ``init_db()`` runs first (presumably creating the ORM tables — see the
    model imports at the top of the module). The category seed and the
    column migrations then run in a fixed order. All steps are plain
    synchronous calls.
    """
    init_db()
    bootstrap_steps = (
        _init_default_categories,
        _migrate_user_is_approved,
        _migrate_project_collect_count,
        _migrate_web_search_enabled,
        _migrate_web_search_count,
    )
    for step in bootstrap_steps:
        step()
|
||
|
||
def _init_default_categories():
    """Seed the categories table with the built-in defaults when it is empty.

    Each default category gets ``sort_order`` equal to its position in the
    list. Nothing is written if at least one category already exists. The
    session is always closed.
    """
    from database import SessionLocal
    from models.category import Category

    session = SessionLocal()
    try:
        if session.query(Category).count() != 0:
            return
        # Frontend, Backend, Deployment, Pitfalls, Best practices, Tools
        default_names = ('前端', '后端', '部署', '踩坑', '最佳实践', '工具')
        session.add_all(
            [Category(name=label, sort_order=position) for position, label in enumerate(default_names)]
        )
        session.commit()
    finally:
        session.close()
|
||
|
||
|
||
def _migrate_user_is_approved():
    """Migration: add the ``is_approved`` column to ``users`` (MySQL).

    The column is first added with DEFAULT 1, so every existing user is
    back-filled as approved. The default is then switched to 0, so newly
    created users require approval.

    MySQL commits each ``ALTER TABLE`` implicitly, so the two steps are not
    atomic. If a previous run added the column but never switched the default
    (second statement failed, or the process stopped in between), the column
    is left with DEFAULT 1. That state is detected here via ``COLUMN_DEFAULT``
    and repaired, instead of being skipped forever because the column exists.

    Best-effort: any error is rolled back and printed, and startup continues.
    """
    from database import SessionLocal
    from sqlalchemy import text

    db = SessionLocal()
    try:
        # None when the column does not exist yet; otherwise a row with its default.
        row = db.execute(text(
            "SELECT COLUMN_DEFAULT FROM information_schema.COLUMNS "
            "WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'users' AND COLUMN_NAME = 'is_approved'"
        )).first()
        if row is None:
            # Add with DEFAULT 1 so already-registered users are auto-approved.
            db.execute(text("ALTER TABLE users ADD COLUMN is_approved TINYINT(1) NOT NULL DEFAULT 1"))
            needs_default_fix = True
        else:
            # A default of 1 can only be the interim state left by this migration.
            # Some servers (e.g. MariaDB) report literal defaults quoted, hence strip("'").
            current_default = row[0]
            needs_default_fix = current_default is not None and str(current_default).strip("'") == "1"
        if needs_default_fix:
            # New users must be approved explicitly.
            db.execute(text("ALTER TABLE users ALTER COLUMN is_approved SET DEFAULT 0"))
        db.commit()
    except Exception as e:
        db.rollback()
        print(f"[migrate] is_approved: {e}")
    finally:
        db.close()
|
||
|
||
|
||
@app.get("/api/health")
async def health_check():
    """Health-check endpoint: always responds with a fixed "ok" payload."""
    payload = {"status": "ok", "message": "极码 GeekCode API 运行中"}
    return payload
|
||
|
||
|
||
def _add_column_if_missing(table, column, column_ddl, label):
    """Add ``column`` to ``table`` unless it already exists (MySQL).

    Existence is checked in ``information_schema.COLUMNS`` for the current
    database, using bound parameters. If the column is missing,
    ``ALTER TABLE <table> ADD COLUMN <column> <column_ddl>`` is executed and
    committed.

    Best-effort: any error is rolled back and printed as
    ``[migrate] <label>: <error>``, so one failed migration does not abort
    application startup. The session is always closed.

    Args:
        table: Table name. Must be a trusted internal constant, because DDL
            identifiers cannot be bound as parameters.
        column: Column name. Same trust requirement as ``table``.
        column_ddl: Column definition placed after the column name in the
            ``ADD COLUMN`` clause (type, nullability, default). Same trust
            requirement.
        label: Prefix used in the failure message.
    """
    from database import SessionLocal
    from sqlalchemy import text

    db = SessionLocal()
    try:
        exists = db.execute(
            text(
                "SELECT COUNT(*) FROM information_schema.COLUMNS "
                "WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = :table AND COLUMN_NAME = :column"
            ),
            {"table": table, "column": column},
        ).scalar()
        if not exists:
            db.execute(text(f"ALTER TABLE {table} ADD COLUMN {column} {column_ddl}"))
            db.commit()
    except Exception as e:
        db.rollback()
        print(f"[migrate] {label}: {e}")
    finally:
        db.close()


def _migrate_project_collect_count():
    """Migration: add ``projects.collect_count`` (INT NOT NULL DEFAULT 0) if missing."""
    _add_column_if_missing(
        "projects", "collect_count", "INT NOT NULL DEFAULT 0", "project collect_count"
    )


def _migrate_web_search_enabled():
    """Migration: add ``ai_model_configs.web_search_enabled`` (TINYINT(1) NOT NULL DEFAULT 0) if missing."""
    _add_column_if_missing(
        "ai_model_configs", "web_search_enabled", "TINYINT(1) NOT NULL DEFAULT 0", "web_search_enabled"
    )


def _migrate_web_search_count():
    """Migration: add ``ai_model_configs.web_search_count`` (INT NOT NULL DEFAULT 5) if missing."""
    _add_column_if_missing(
        "ai_model_configs", "web_search_count", "INT NOT NULL DEFAULT 5", "web_search_count"
    )
|