137 lines
5.0 KiB
Python
137 lines
5.0 KiB
Python
"""认证路由"""
|
||
from datetime import datetime, timedelta, timezone
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from jose import JWTError, jwt
from passlib.context import CryptContext
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

from database import get_db
from models.user import User
from schemas.user import UserRegister, UserLogin, UserResponse, TokenResponse, UserUpdate
from config import SECRET_KEY, ALGORITHM, ACCESS_TOKEN_EXPIRE_MINUTES
|
||
|
||
# Router that the authentication endpoints in this module are registered on.
router = APIRouter()
# Password hashing context (bcrypt); used below to hash new passwords and verify submitted ones.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# Bearer-token extractor; injected into get_current_user via Depends(security).
security = HTTPBearer()
|
||
|
||
|
||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT access token.

    Args:
        data: Claims to embed in the token (e.g. ``{"sub": "<user id>"}``).
            The mapping is copied, so the caller's dict is not modified.
        expires_delta: Optional token lifetime. When omitted, the token
            expires after ``ACCESS_TOKEN_EXPIRE_MINUTES`` minutes.

    Returns:
        The encoded JWT, signed with ``SECRET_KEY`` using ``ALGORITHM``.
    """
    to_encode = data.copy()
    if expires_delta is None:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    # Timezone-aware UTC "now": datetime.utcnow() is deprecated since
    # Python 3.12 and returns a naive datetime.
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||
|
||
|
||
def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: Session = Depends(get_db),
) -> User:
    """Resolve the authenticated user from the request's Bearer token.

    The token is decoded with ``SECRET_KEY``/``ALGORITHM``; its ``sub`` claim
    must hold the user's integer id.

    Returns:
        The ``User`` row matching the token's ``sub`` claim.

    Raises:
        HTTPException(401): the token is invalid/expired, ``sub`` is missing
            or not an integer, or the user no longer exists.
        HTTPException(403): the account has been banned.
    """
    bearer_challenge = {"WWW-Authenticate": "Bearer"}
    try:
        payload = jwt.decode(credentials.credentials, SECRET_KEY, algorithms=[ALGORITHM])
        # int(None) raises TypeError and int("abc") raises ValueError; both
        # mean a malformed token and must yield 401, not an unhandled 500.
        user_id = int(payload.get("sub"))
    except (JWTError, TypeError, ValueError):
        raise HTTPException(
            status_code=401, detail="无效的认证凭据", headers=bearer_challenge
        ) from None

    user = db.query(User).filter(User.id == user_id).first()
    if user is None:
        raise HTTPException(status_code=401, detail="用户不存在", headers=bearer_challenge)

    # Mirror the ban check in /login so a ban takes effect immediately
    # instead of only after previously issued tokens expire.
    if getattr(user, 'is_banned', False):
        raise HTTPException(status_code=403, detail="账号已被封禁,请联系管理员")

    return user
|
||
|
||
|
||
def get_admin_user(current_user: User = Depends(get_current_user)) -> User:
    """Dependency that only admits administrators.

    Returns the authenticated user unchanged when its ``is_admin`` flag is
    truthy; otherwise the request is rejected with HTTP 403.
    """
    if current_user.is_admin:
        return current_user
    raise HTTPException(status_code=403, detail="需要管理员权限")
|
||
|
||
|
||
@router.post("/register")
def register(data: UserRegister, db: Session = Depends(get_db)):
    """Register a new user account (usable only after administrator approval).

    Returns:
        A dict with a ``message`` telling the user to wait for approval.

    Raises:
        HTTPException(400): the username or e-mail is already registered,
            including when a concurrent request claimed it first.
    """
    if db.query(User).filter(User.username == data.username).first():
        raise HTTPException(status_code=400, detail="用户名已存在")
    if db.query(User).filter(User.email == data.email).first():
        raise HTTPException(status_code=400, detail="邮箱已被注册")

    # No approval flag is set here: this relies on the User model defaulting
    # is_approved to False so the account waits for review (model not shown
    # here -- confirm).
    user = User(
        username=data.username,
        email=data.email,
        password_hash=pwd_context.hash(data.password),
    )
    db.add(user)
    try:
        db.commit()
    except IntegrityError:
        # Another request may have inserted the same username/e-mail between
        # the checks above and this commit (assumes unique constraints on
        # those columns -- confirm). Roll back and report a client error
        # instead of a 500.
        db.rollback()
        raise HTTPException(status_code=400, detail="用户名或邮箱已存在") from None

    return {"message": "注册成功,请等待管理员审核通过后即可登录使用"}
|
||
|
||
|
||
@router.post("/login", response_model=TokenResponse)
def login(data: UserLogin, db: Session = Depends(get_db)):
    """Authenticate with username and password and issue an access token.

    Returns:
        ``TokenResponse`` holding the JWT and the user's public profile.

    Raises:
        HTTPException(401): unknown username or wrong password (same message
            for both, so the response does not reveal which one failed).
        HTTPException(403): the account is banned or not yet approved.
    """
    user = db.query(User).filter(User.username == data.username).first()
    if user is None:
        # Burn roughly the time of a real bcrypt verification so response
        # timing does not reveal whether the username exists.
        pwd_context.dummy_verify()
        raise HTTPException(status_code=401, detail="用户名或密码错误")
    if not pwd_context.verify(data.password, user.password_hash):
        raise HTTPException(status_code=401, detail="用户名或密码错误")

    if getattr(user, 'is_banned', False):
        raise HTTPException(status_code=403, detail="账号已被封禁,请联系管理员")

    if not getattr(user, 'is_approved', False):
        raise HTTPException(status_code=403, detail="账号尚未通过审核,请耐心等待管理员审核")

    # get_current_user parses "sub" back into an int user id.
    token = create_access_token({"sub": str(user.id)})
    return TokenResponse(
        access_token=token,
        user=UserResponse.model_validate(user),
    )
|
||
|
||
|
||
@router.get("/me", response_model=UserResponse)
def get_me(current_user: User = Depends(get_current_user)):
    """Return the authenticated user's profile as a ``UserResponse``."""
    profile = UserResponse.model_validate(current_user)
    return profile
|
||
|
||
|
||
@router.put("/profile", response_model=UserResponse)
def update_profile(
    data: UserUpdate,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Update the authenticated user's profile.

    Which fields change:
        * ``username`` and ``email`` are changed only when truthy and different
          from the current value.
        * ``avatar`` is changed whenever it is not None.
        * The password is changed only when ``new_password`` is given together
          with the correct ``old_password``.

    All checks run before ``current_user`` is modified. A rejected request
    therefore never leaves half-applied changes on the session object, and
    never lets autoflush write them out.

    Returns:
        The updated user as a ``UserResponse``.

    Raises:
        HTTPException(400): username/e-mail already taken, or the current
            password is missing or wrong.
    """
    change_username = bool(data.username) and data.username != current_user.username
    change_email = bool(data.email) and data.email != current_user.email

    # --- Validation (no mutation yet; order matches the error precedence) ---
    if change_username and db.query(User).filter(
        User.username == data.username, User.id != current_user.id
    ).first():
        raise HTTPException(status_code=400, detail="用户名已存在")

    if change_email and db.query(User).filter(
        User.email == data.email, User.id != current_user.id
    ).first():
        raise HTTPException(status_code=400, detail="邮箱已被使用")

    if data.new_password:
        if not data.old_password:
            raise HTTPException(status_code=400, detail="请输入当前密码")
        if not pwd_context.verify(data.old_password, current_user.password_hash):
            raise HTTPException(status_code=400, detail="当前密码错误")

    # --- Apply changes ---
    if change_username:
        current_user.username = data.username
    if change_email:
        current_user.email = data.email
    if data.avatar is not None:
        current_user.avatar = data.avatar
    if data.new_password:
        current_user.password_hash = pwd_context.hash(data.new_password)

    try:
        db.commit()
    except IntegrityError:
        # A concurrent request may have claimed the username/e-mail after the
        # checks above (assumes unique constraints -- confirm); report it as a
        # client error instead of a 500.
        db.rollback()
        raise HTTPException(status_code=400, detail="用户名或邮箱已被使用") from None

    db.refresh(current_user)
    return UserResponse.model_validate(current_user)
|