docs(readme): 编写项目README文档,描述功能与架构

- 完整撰写玉宗珠宝设计大师项目README,介绍项目概况及核心功能
- 说明用户认证系统实现及优势,包含JWT鉴权和密码加密细节
- 详细描述品类管理系统,支持多流程类型和多种玉石品类
- 说明设计图生成方案及技术,包含Pillow生成示例及字体支持
- 介绍设计管理功能,支持分页浏览、预览、下载和删除设计
- 个人信息管理模块说明,涵盖昵称、手机号、密码的安全修改
- 绘制业务流程图和关键数据流图,清晰展现系统架构与数据流
- 提供详细API调用链路及参数说明,涵盖用户、品类、设计接口
- 列明技术栈及版本,包含前后端框架、ORM、认证、加密等工具
- 展示目录结构,标明后端与前端项目布局
- 规划本地开发环境与启动步骤,包括数据库初始化及运行命令
- 说明服务器部署流程和Nginx配置方案
- 详细数据库表结构说明及环境变量配置指导
- 汇总常用开发及测试命令,方便开发调试与部署管理
This commit is contained in:
changyoutongxue
2026-03-27 13:10:17 +08:00
commit e3ff55b4db
69 changed files with 8551 additions and 0 deletions

1
backend/app/__init__.py Normal file
View File

@@ -0,0 +1 @@
# 玉宗 - 珠宝设计大师 后端应用

28
backend/app/config.py Normal file
View File

@@ -0,0 +1,28 @@
"""
应用配置管理
使用 pydantic-settings 从环境变量读取配置
"""
from pydantic_settings import BaseSettings
from functools import lru_cache
class Settings(BaseSettings):
"""应用配置"""
DATABASE_URL: str = "mysql+pymysql://root:password@localhost:3306/yuzong"
SECRET_KEY: str = "your-secret-key-change-this"
ALGORITHM: str = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES: int = 1440
UPLOAD_DIR: str = "uploads"
class Config:
env_file = ".env"
env_file_encoding = "utf-8"
@lru_cache()
def get_settings() -> Settings:
"""获取配置单例"""
return Settings()
settings = get_settings()

37
backend/app/database.py Normal file
View File

@@ -0,0 +1,37 @@
"""
数据库连接配置
使用 SQLAlchemy 2.0 同步方式
"""
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from typing import Generator
from .config import settings
# 创建数据库引擎
engine = create_engine(
settings.DATABASE_URL,
pool_pre_ping=True,
pool_recycle=3600,
echo=False
)
# 创建会话工厂
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# 创建基类
Base = declarative_base()
def get_db() -> Generator:
"""
数据库会话依赖注入
用于 FastAPI 的依赖注入系统
"""
db = SessionLocal()
try:
yield db
finally:
db.close()

67
backend/app/main.py Normal file
View File

@@ -0,0 +1,67 @@
"""
玉宗 - 珠宝设计大师 后端服务入口
"""
import os
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from .config import settings
from .routers import categories, designs, users
from .routers import auth
@asynccontextmanager
async def lifespan(app: FastAPI):
"""应用生命周期管理"""
# 启动时:创建 uploads 目录
os.makedirs(settings.UPLOAD_DIR, exist_ok=True)
print(f"✅ 上传目录已准备: {settings.UPLOAD_DIR}")
yield
# 关闭时:清理资源(如需要)
print("👋 应用已关闭")
# 创建 FastAPI 应用实例
app = FastAPI(
title="玉宗 - 珠宝设计大师",
description="AI驱动的珠宝设计微信小程序后端服务",
version="1.0.0",
lifespan=lifespan
)
# 配置 CORS
app.add_middleware(
CORSMiddleware,
allow_origins=["http://localhost:3000"], # 生产环境应限制具体域名
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/")
async def root():
"""根路径健康检查"""
return {
"message": "玉宗 - 珠宝设计大师 API",
"version": "1.0.0",
"status": "running"
}
@app.get("/health")
async def health_check():
"""健康检查接口"""
return {"status": "healthy"}
# 注册路由
app.include_router(auth.router)
app.include_router(categories.router)
app.include_router(designs.router)
app.include_router(users.router)
# 配置静态文件服务
app.mount("/uploads", StaticFiles(directory="uploads"), name="uploads")

View File

@@ -0,0 +1,17 @@
"""
数据库模型
导出所有模型,确保 Base.metadata 包含所有表
"""
from ..database import Base
from .user import User
from .category import Category, SubType, Color
from .design import Design
__all__ = [
"Base",
"User",
"Category",
"SubType",
"Color",
"Design"
]

View File

@@ -0,0 +1,64 @@
"""
品类相关模型
包含:品类、子类型、颜色
"""
from sqlalchemy import Column, Integer, String, ForeignKey
from sqlalchemy.orm import relationship
from ..database import Base
class Category(Base):
"""品类表"""
__tablename__ = "categories"
id = Column(Integer, primary_key=True, autoincrement=True, comment="品类ID")
name = Column(String(50), nullable=False, comment="品类名称")
icon = Column(String(255), nullable=True, comment="品类图标")
sort_order = Column(Integer, default=0, comment="排序")
flow_type = Column(String(20), nullable=False, comment="流程类型full/size_color/simple")
# 关联关系
sub_types = relationship("SubType", back_populates="category")
colors = relationship("Color", back_populates="category")
designs = relationship("Design", back_populates="category")
def __repr__(self):
return f"<Category(id={self.id}, name='{self.name}')>"
class SubType(Base):
"""子类型表"""
__tablename__ = "sub_types"
id = Column(Integer, primary_key=True, autoincrement=True, comment="子类型ID")
category_id = Column(Integer, ForeignKey("categories.id"), nullable=False, comment="所属品类")
name = Column(String(50), nullable=False, comment="名称")
description = Column(String(255), nullable=True, comment="描述")
preview_image = Column(String(255), nullable=True, comment="预览图")
sort_order = Column(Integer, default=0, comment="排序")
# 关联关系
category = relationship("Category", back_populates="sub_types")
designs = relationship("Design", back_populates="sub_type")
def __repr__(self):
return f"<SubType(id={self.id}, name='{self.name}')>"
class Color(Base):
"""颜色表"""
__tablename__ = "colors"
id = Column(Integer, primary_key=True, autoincrement=True, comment="颜色ID")
category_id = Column(Integer, ForeignKey("categories.id"), nullable=False, comment="适用品类")
name = Column(String(50), nullable=False, comment="颜色名称")
hex_code = Column(String(7), nullable=False, comment="色值")
sort_order = Column(Integer, default=0, comment="排序")
# 关联关系
category = relationship("Category", back_populates="colors")
designs = relationship("Design", back_populates="color")
def __repr__(self):
return f"<Color(id={self.id}, name='{self.name}', hex_code='{self.hex_code}')>"

View File

@@ -0,0 +1,39 @@
"""
设计作品模型
"""
from sqlalchemy import Column, BigInteger, Integer, String, Text, DateTime, ForeignKey
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship
from ..database import Base
class Design(Base):
"""设计作品表"""
__tablename__ = "designs"
id = Column(BigInteger, primary_key=True, autoincrement=True, comment="设计ID")
user_id = Column(BigInteger, ForeignKey("users.id"), nullable=False, comment="用户ID")
category_id = Column(Integer, ForeignKey("categories.id"), nullable=False, comment="品类ID")
sub_type_id = Column(Integer, ForeignKey("sub_types.id"), nullable=True, comment="子类型ID")
color_id = Column(Integer, ForeignKey("colors.id"), nullable=True, comment="颜色ID")
prompt = Column(Text, nullable=False, comment="设计需求")
carving_technique = Column(String(50), nullable=True, comment="雕刻工艺")
design_style = Column(String(50), nullable=True, comment="设计风格")
motif = Column(String(100), nullable=True, comment="题材纹样")
size_spec = Column(String(100), nullable=True, comment="尺寸规格")
surface_finish = Column(String(50), nullable=True, comment="表面处理")
usage_scene = Column(String(50), nullable=True, comment="用途场景")
image_url = Column(String(255), nullable=True, comment="设计图URL")
status = Column(String(20), default="generating", comment="状态")
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
# 关联关系
user = relationship("User", back_populates="designs")
category = relationship("Category", back_populates="designs")
sub_type = relationship("SubType", back_populates="designs")
color = relationship("Color", back_populates="designs")
def __repr__(self):
return f"<Design(id={self.id}, status='{self.status}')>"

View File

@@ -0,0 +1,28 @@
"""
用户模型
"""
from sqlalchemy import Column, BigInteger, String, DateTime
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship
from ..database import Base
class User(Base):
"""用户表"""
__tablename__ = "users"
id = Column(BigInteger, primary_key=True, autoincrement=True, comment="用户ID")
username = Column(String(50), unique=True, nullable=False, comment="用户名")
phone = Column(String(20), unique=True, nullable=True, comment="手机号")
hashed_password = Column(String(255), nullable=False, comment="加密密码")
nickname = Column(String(50), nullable=True, comment="昵称")
avatar = Column(String(255), nullable=True, comment="头像URL")
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
# 关联关系
designs = relationship("Design", back_populates="user")
def __repr__(self):
return f"<User(id={self.id}, username='{self.username}')>"

View File

@@ -0,0 +1,8 @@
# API 路由模块
from . import categories, designs, users
__all__ = [
"categories",
"designs",
"users",
]

View File

@@ -0,0 +1,63 @@
"""
认证路由
提供用户注册、登录和获取当前用户信息的 API
"""
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from ..database import get_db
from ..schemas.user import UserCreate, UserLogin, UserResponse, Token
from ..services.auth_service import register_user, authenticate_user
from ..utils.deps import get_current_user
from ..utils.security import create_access_token
from ..models.user import User
router = APIRouter(prefix="/api/auth", tags=["认证"])
@router.post("/register", response_model=UserResponse)
def register(user_data: UserCreate, db: Session = Depends(get_db)):
"""
用户注册
创建新用户账号,用户名必须唯一
"""
try:
user = register_user(db, user_data)
return user
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e)
)
@router.post("/login", response_model=Token)
def login(user_data: UserLogin, db: Session = Depends(get_db)):
"""
用户登录
验证用户名和密码,返回 JWT access token
"""
user = authenticate_user(db, user_data.username, user_data.password)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="用户名或密码错误",
headers={"WWW-Authenticate": "Bearer"},
)
# 生成 JWT tokensub 字段存储用户 ID
access_token = create_access_token(data={"sub": str(user.id)})
return Token(access_token=access_token, token_type="bearer")
@router.get("/me", response_model=UserResponse)
def get_me(current_user: User = Depends(get_current_user)):
"""
获取当前登录用户信息
需要认证,从 token 中解析用户身份
"""
return current_user

View File

@@ -0,0 +1,71 @@
"""
品类相关路由
提供品类、子类型、颜色的查询接口
"""
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from typing import List
from ..database import get_db
from ..models import Category, SubType, Color
from ..schemas import CategoryResponse, SubTypeResponse, ColorResponse
router = APIRouter(prefix="/api/categories", tags=["品类"])
@router.get("", response_model=List[CategoryResponse])
def get_categories(db: Session = Depends(get_db)):
"""
获取所有品类列表
按 sort_order 排序,无需认证
"""
categories = db.query(Category).order_by(Category.sort_order).all()
return categories
@router.get("/{category_id}/sub-types", response_model=List[SubTypeResponse])
def get_category_sub_types(
category_id: int,
db: Session = Depends(get_db)
):
"""
获取品类下的子类型
无需认证
"""
# 检查品类是否存在
category = db.query(Category).filter(Category.id == category_id).first()
if not category:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="品类不存在"
)
sub_types = db.query(SubType).filter(
SubType.category_id == category_id
).order_by(SubType.sort_order).all()
return sub_types
@router.get("/{category_id}/colors", response_model=List[ColorResponse])
def get_category_colors(
category_id: int,
db: Session = Depends(get_db)
):
"""
获取品类下的颜色选项
无需认证
"""
# 检查品类是否存在
category = db.query(Category).filter(Category.id == category_id).first()
if not category:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="品类不存在"
)
colors = db.query(Color).filter(
Color.category_id == category_id
).order_by(Color.sort_order).all()
return colors

View File

@@ -0,0 +1,201 @@
"""
设计相关路由
提供设计生成、查询、删除、下载接口
"""
import os
from fastapi import APIRouter, Depends, HTTPException, status, Query
from fastapi.responses import FileResponse
from sqlalchemy.orm import Session
from ..database import get_db
from ..models import User, Design
from ..schemas import DesignCreate, DesignResponse, DesignListResponse
from ..utils.deps import get_current_user
from ..services import design_service
router = APIRouter(prefix="/api/designs", tags=["设计"])
def design_to_response(design: Design) -> DesignResponse:
"""将 Design 模型转换为响应格式"""
return DesignResponse(
id=design.id,
user_id=design.user_id,
category={
"id": design.category.id,
"name": design.category.name,
"icon": design.category.icon,
"sort_order": design.category.sort_order,
"flow_type": design.category.flow_type
},
sub_type={
"id": design.sub_type.id,
"category_id": design.sub_type.category_id,
"name": design.sub_type.name,
"description": design.sub_type.description,
"preview_image": design.sub_type.preview_image,
"sort_order": design.sub_type.sort_order
} if design.sub_type else None,
color={
"id": design.color.id,
"category_id": design.color.category_id,
"name": design.color.name,
"hex_code": design.color.hex_code,
"sort_order": design.color.sort_order
} if design.color else None,
prompt=design.prompt,
carving_technique=design.carving_technique,
design_style=design.design_style,
motif=design.motif,
size_spec=design.size_spec,
surface_finish=design.surface_finish,
usage_scene=design.usage_scene,
image_url=design.image_url,
status=design.status,
created_at=design.created_at,
updated_at=design.updated_at
)
@router.post("/generate", response_model=DesignResponse)
def generate_design(
design_data: DesignCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
提交设计生成请求
需要认证
"""
try:
design = design_service.create_design(
db=db,
user_id=current_user.id,
design_data=design_data
)
return design_to_response(design)
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e)
)
@router.get("", response_model=DesignListResponse)
def get_designs(
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
获取当前用户的设计历史列表(分页)
需要认证
"""
designs, total = design_service.get_user_designs(
db=db,
user_id=current_user.id,
page=page,
page_size=page_size
)
return DesignListResponse(
items=[design_to_response(d) for d in designs],
total=total,
page=page,
page_size=page_size
)
@router.get("/{design_id}", response_model=DesignResponse)
def get_design(
design_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
获取设计详情
只能查看自己的设计,非本人设计返回 404
"""
design = design_service.get_design_by_id(
db=db,
design_id=design_id,
user_id=current_user.id
)
if not design:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="设计不存在"
)
return design_to_response(design)
@router.delete("/{design_id}")
def delete_design(
design_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
删除设计
只能删除自己的设计,非本人设计返回 404
"""
success = design_service.delete_design(
db=db,
design_id=design_id,
user_id=current_user.id
)
if not success:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="设计不存在"
)
return {"message": "删除成功"}
@router.get("/{design_id}/download")
def download_design(
design_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
下载设计图
只能下载自己的设计,非本人设计返回 404
"""
design = design_service.get_design_by_id(
db=db,
design_id=design_id,
user_id=current_user.id
)
if not design:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="设计不存在"
)
if not design.image_url:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="设计图片不存在"
)
# 转换 URL 为文件路径
file_path = design.image_url.lstrip("/")
if not os.path.exists(file_path):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="设计图片文件不存在"
)
return FileResponse(
path=file_path,
filename=f"design_{design_id}.png",
media_type="image/png"
)

View File

@@ -0,0 +1,75 @@
"""
用户相关路由
提供用户信息更新、密码修改接口
"""
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from ..database import get_db
from ..models import User
from ..schemas import UserResponse, UserUpdate, PasswordChange
from ..utils.deps import get_current_user
from ..utils.security import verify_password, get_password_hash
router = APIRouter(prefix="/api/users", tags=["用户"])
@router.put("/profile", response_model=UserResponse)
def update_profile(
user_data: UserUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
更新个人信息
需要认证
"""
# 更新非空字段
if user_data.nickname is not None:
current_user.nickname = user_data.nickname
if user_data.phone is not None:
# 检查手机号是否已被其他用户使用
if user_data.phone:
existing_user = db.query(User).filter(
User.phone == user_data.phone,
User.id != current_user.id
).first()
if existing_user:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="手机号已被使用"
)
current_user.phone = user_data.phone
if user_data.avatar is not None:
current_user.avatar = user_data.avatar
db.commit()
db.refresh(current_user)
return current_user
@router.put("/password")
def change_password(
password_data: PasswordChange,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
修改密码
需要认证,旧密码错误返回 400
"""
# 验证旧密码
if not verify_password(password_data.old_password, current_user.hashed_password):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="旧密码错误"
)
# 更新密码
current_user.hashed_password = get_password_hash(password_data.new_password)
db.commit()
return {"message": "密码修改成功"}

View File

@@ -0,0 +1,25 @@
"""
Pydantic Schemas
导出所有 Schema 类型
"""
from .user import UserCreate, UserLogin, UserResponse, Token, UserUpdate, PasswordChange
from .category import CategoryResponse, SubTypeResponse, ColorResponse
from .design import DesignCreate, DesignResponse, DesignListResponse
__all__ = [
# User schemas
"UserCreate",
"UserLogin",
"UserResponse",
"Token",
"UserUpdate",
"PasswordChange",
# Category schemas
"CategoryResponse",
"SubTypeResponse",
"ColorResponse",
# Design schemas
"DesignCreate",
"DesignResponse",
"DesignListResponse",
]

View File

@@ -0,0 +1,42 @@
"""
品类相关 Pydantic Schemas
"""
from pydantic import BaseModel
from typing import Optional
class CategoryResponse(BaseModel):
"""品类响应"""
id: int
name: str
icon: Optional[str] = None
sort_order: int
flow_type: str
class Config:
from_attributes = True
class SubTypeResponse(BaseModel):
"""子类型响应"""
id: int
category_id: int
name: str
description: Optional[str] = None
preview_image: Optional[str] = None
sort_order: int
class Config:
from_attributes = True
class ColorResponse(BaseModel):
"""颜色响应"""
id: int
category_id: int
name: str
hex_code: str
sort_order: int
class Config:
from_attributes = True

View File

@@ -0,0 +1,53 @@
"""
设计作品相关 Pydantic Schemas
"""
from pydantic import BaseModel, Field
from datetime import datetime
from typing import Optional, List
from .category import CategoryResponse, SubTypeResponse, ColorResponse
class DesignCreate(BaseModel):
"""创建设计请求"""
category_id: int = Field(..., description="品类ID")
sub_type_id: Optional[int] = Field(None, description="子类型ID")
color_id: Optional[int] = Field(None, description="颜色ID")
prompt: str = Field(..., min_length=1, max_length=2000, description="设计需求")
carving_technique: Optional[str] = Field(None, max_length=50, description="雕刻工艺")
design_style: Optional[str] = Field(None, max_length=50, description="设计风格")
motif: Optional[str] = Field(None, max_length=100, description="题材纹样")
size_spec: Optional[str] = Field(None, max_length=100, description="尺寸规格")
surface_finish: Optional[str] = Field(None, max_length=50, description="表面处理")
usage_scene: Optional[str] = Field(None, max_length=50, description="用途场景")
class DesignResponse(BaseModel):
"""设计作品响应"""
id: int
user_id: int
category: CategoryResponse
sub_type: Optional[SubTypeResponse] = None
color: Optional[ColorResponse] = None
prompt: str
carving_technique: Optional[str] = None
design_style: Optional[str] = None
motif: Optional[str] = None
size_spec: Optional[str] = None
surface_finish: Optional[str] = None
usage_scene: Optional[str] = None
image_url: Optional[str] = None
status: str
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
class DesignListResponse(BaseModel):
"""设计作品列表响应"""
items: List[DesignResponse]
total: int
page: int
page_size: int

View File

@@ -0,0 +1,51 @@
"""
用户相关 Pydantic Schemas
"""
from pydantic import BaseModel, Field
from datetime import datetime
from typing import Optional
class UserCreate(BaseModel):
"""用户注册请求"""
username: str = Field(..., min_length=2, max_length=50, description="用户名")
password: str = Field(..., min_length=6, max_length=100, description="密码")
nickname: Optional[str] = Field(None, max_length=50, description="昵称")
class UserLogin(BaseModel):
"""用户登录请求"""
username: str = Field(..., description="用户名")
password: str = Field(..., description="密码")
class UserResponse(BaseModel):
"""用户响应"""
id: int
username: str
nickname: Optional[str] = None
phone: Optional[str] = None
avatar: Optional[str] = None
created_at: datetime
class Config:
from_attributes = True
class Token(BaseModel):
"""认证令牌响应"""
access_token: str
token_type: str = "bearer"
class UserUpdate(BaseModel):
"""用户信息更新请求"""
nickname: Optional[str] = Field(None, max_length=50, description="昵称")
phone: Optional[str] = Field(None, max_length=20, description="手机号")
avatar: Optional[str] = Field(None, max_length=255, description="头像URL")
class PasswordChange(BaseModel):
"""修改密码请求"""
old_password: str = Field(..., description="旧密码")
new_password: str = Field(..., min_length=6, max_length=100, description="新密码")

View File

@@ -0,0 +1,8 @@
# 业务服务模块
from . import design_service
from .mock_generator import generate_mock_design
__all__ = [
"design_service",
"generate_mock_design",
]

View File

@@ -0,0 +1,67 @@
"""
认证服务
提供用户注册和登录业务逻辑
"""
from typing import Optional
from sqlalchemy.orm import Session
from ..models.user import User
from ..schemas.user import UserCreate
from ..utils.security import get_password_hash, verify_password
def register_user(db: Session, user_data: UserCreate) -> User:
"""
注册新用户
Args:
db: 数据库会话
user_data: 用户注册数据
Returns:
创建的用户对象
Raises:
ValueError: 用户名已存在时抛出
"""
# 检查用户名是否已存在
existing_user = db.query(User).filter(User.username == user_data.username).first()
if existing_user:
raise ValueError("用户名已存在")
# 创建新用户,密码加密存储
db_user = User(
username=user_data.username,
hashed_password=get_password_hash(user_data.password),
nickname=user_data.nickname or user_data.username
)
db.add(db_user)
db.commit()
db.refresh(db_user)
return db_user
def authenticate_user(db: Session, username: str, password: str) -> Optional[User]:
"""
验证用户登录
Args:
db: 数据库会话
username: 用户名
password: 明文密码
Returns:
验证成功返回用户对象,失败返回 None
"""
# 查询用户
user = db.query(User).filter(User.username == username).first()
if not user:
return None
# 验证密码
if not verify_password(password, user.hashed_password):
return None
return user

View File

@@ -0,0 +1,150 @@
"""
设计服务
处理设计相关的业务逻辑
"""
import os
from typing import List, Optional, Tuple
from sqlalchemy.orm import Session
from sqlalchemy import desc
from ..models import Design, Category, SubType, Color
from ..schemas import DesignCreate
from ..config import settings
from .mock_generator import generate_mock_design
def create_design(db: Session, user_id: int, design_data: DesignCreate) -> Design:
"""
创建设计记录
1. 创建设计记录status=generating
2. 调用 mock_generator 生成图片
3. 更新设计记录status=completed, image_url
4. 返回设计对象
"""
# 获取关联信息
category = db.query(Category).filter(Category.id == design_data.category_id).first()
if not category:
raise ValueError(f"品类不存在: {design_data.category_id}")
sub_type = None
if design_data.sub_type_id:
sub_type = db.query(SubType).filter(SubType.id == design_data.sub_type_id).first()
color = None
if design_data.color_id:
color = db.query(Color).filter(Color.id == design_data.color_id).first()
# 创建设计记录
design = Design(
user_id=user_id,
category_id=design_data.category_id,
sub_type_id=design_data.sub_type_id,
color_id=design_data.color_id,
prompt=design_data.prompt,
carving_technique=design_data.carving_technique,
design_style=design_data.design_style,
motif=design_data.motif,
size_spec=design_data.size_spec,
surface_finish=design_data.surface_finish,
usage_scene=design_data.usage_scene,
status="generating"
)
db.add(design)
db.flush() # 获取 ID
# 生成图片
save_path = os.path.join(settings.UPLOAD_DIR, "designs", f"{design.id}.png")
image_url = generate_mock_design(
category_name=category.name,
sub_type_name=sub_type.name if sub_type else None,
color_name=color.name if color else None,
prompt=design_data.prompt,
save_path=save_path,
carving_technique=design_data.carving_technique,
design_style=design_data.design_style,
motif=design_data.motif,
size_spec=design_data.size_spec,
surface_finish=design_data.surface_finish,
usage_scene=design_data.usage_scene,
)
# 更新设计记录
design.image_url = image_url
design.status = "completed"
db.commit()
db.refresh(design)
return design
def get_user_designs(
db: Session,
user_id: int,
page: int = 1,
page_size: int = 20
) -> Tuple[List[Design], int]:
"""
分页查询用户设计历史
Returns:
(设计列表, 总数)
"""
query = db.query(Design).filter(Design.user_id == user_id)
# 获取总数
total = query.count()
# 分页查询,按创建时间倒序
offset = (page - 1) * page_size
designs = query.order_by(desc(Design.created_at)).offset(offset).limit(page_size).all()
return designs, total
def get_design_by_id(db: Session, design_id: int, user_id: int) -> Optional[Design]:
"""
获取单个设计
只返回属于该用户的设计
"""
return db.query(Design).filter(
Design.id == design_id,
Design.user_id == user_id
).first()
def delete_design(db: Session, design_id: int, user_id: int) -> bool:
"""
删除设计
1. 查找设计(必须属于该用户)
2. 删除图片文件
3. 删除数据库记录
Returns:
是否删除成功
"""
design = db.query(Design).filter(
Design.id == design_id,
Design.user_id == user_id
).first()
if not design:
return False
# 删除图片文件
if design.image_url:
# image_url 格式: /uploads/designs/1001.png
# 转换为实际文件路径
file_path = design.image_url.lstrip("/")
if os.path.exists(file_path):
try:
os.remove(file_path)
except Exception:
pass # 忽略删除失败
# 删除数据库记录
db.delete(design)
db.commit()
return True

View File

@@ -0,0 +1,222 @@
"""
Mock 图片生成服务
使用 Pillow 生成带文字的占位设计图
"""
import os
from typing import Optional, Tuple, Union
from PIL import Image, ImageDraw, ImageFont
# 颜色映射表(中文颜色名 -> 十六进制)
COLOR_MAP = {
# 和田玉国标色种
"白玉": "#FEFEF2",
"青白玉": "#E8EDE4",
"青玉": "#7A8B6E",
"碧玉": "#2D5F2D",
"翠青": "#6BAF8D",
"黄玉": "#D4A843",
"糖玉": "#C4856C",
"墨玉": "#2C2C2C",
"藕粉": "#E8B4B8",
"烟紫": "#8B7D9B",
# 原有颜色
"糖白": "#F5F0E8",
# 通用颜色
"白色": "#FFFFFF",
"黑色": "#333333",
"红色": "#C41E3A",
"绿色": "#228B22",
"蓝色": "#4169E1",
"黄色": "#FFD700",
"紫色": "#9370DB",
"粉色": "#FFB6C1",
"橙色": "#FF8C00",
}
# 默认背景色(浅灰)
DEFAULT_BG_COLOR = "#E8E4DF"
def hex_to_rgb(hex_color: str) -> Tuple[int, int, int]:
"""将十六进制颜色转换为 RGB 元组"""
hex_color = hex_color.lstrip('#')
return tuple(int(hex_color[i:i+2], 16) for i in (0, 2, 4))
def get_contrast_text_color(bg_color: str) -> str:
"""根据背景色计算合适的文字颜色(黑或白)"""
r, g, b = hex_to_rgb(bg_color)
# 使用亮度公式
brightness = (r * 299 + g * 587 + b * 114) / 1000
return "#333333" if brightness > 128 else "#FFFFFF"
def get_font(size: int = 24) -> Union[ImageFont.FreeTypeFont, ImageFont.ImageFont]:
"""
获取字体,优先使用系统中文字体
"""
# 常见中文字体路径
font_paths = [
# macOS
"/System/Library/Fonts/PingFang.ttc",
"/System/Library/Fonts/STHeiti Light.ttc",
"/System/Library/Fonts/Supplemental/Arial Unicode.ttf",
"/Library/Fonts/Arial Unicode.ttf",
# Linux
"/usr/share/fonts/truetype/droid/DroidSansFallbackFull.ttf",
"/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc",
"/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc",
# Windows
"C:\\Windows\\Fonts\\msyh.ttc",
"C:\\Windows\\Fonts\\simsun.ttc",
]
for font_path in font_paths:
if os.path.exists(font_path):
try:
return ImageFont.truetype(font_path, size)
except Exception:
continue
# 回退到默认字体
return ImageFont.load_default()
def generate_mock_design(
category_name: str,
sub_type_name: Optional[str],
color_name: Optional[str],
prompt: str,
save_path: str,
carving_technique: Optional[str] = None,
design_style: Optional[str] = None,
motif: Optional[str] = None,
size_spec: Optional[str] = None,
surface_finish: Optional[str] = None,
usage_scene: Optional[str] = None,
) -> str:
"""
生成 Mock 设计图
Args:
category_name: 品类名称
sub_type_name: 子类型名称(可选)
color_name: 颜色名称(可选)
prompt: 用户设计需求
save_path: 保存路径
Returns:
相对 URL 路径,如 /uploads/designs/1001.png
"""
# 确定背景色
if color_name and color_name in COLOR_MAP:
bg_color = COLOR_MAP[color_name]
elif color_name:
# 尝试直接使用颜色名(可能是十六进制)
bg_color = color_name if color_name.startswith("#") else DEFAULT_BG_COLOR
else:
bg_color = DEFAULT_BG_COLOR
# 创建图片
width, height = 800, 800
bg_rgb = hex_to_rgb(bg_color)
image = Image.new("RGB", (width, height), bg_rgb)
draw = ImageDraw.Draw(image)
# 获取文字颜色(与背景对比)
text_color = get_contrast_text_color(bg_color)
text_rgb = hex_to_rgb(text_color)
# 获取字体
title_font = get_font(48)
info_font = get_font(32)
prompt_font = get_font(28)
# 绘制标题
title = "玉宗设计"
draw.text((width // 2, 100), title, font=title_font, fill=text_rgb, anchor="mm")
# 绘制分隔线
line_y = 160
draw.line([(100, line_y), (700, line_y)], fill=text_rgb, width=2)
# 绘制品类信息
y_position = 220
info_lines = [f"品类: {category_name}"]
if sub_type_name:
info_lines.append(f"类型: {sub_type_name}")
if color_name:
info_lines.append(f"颜色: {color_name}")
if carving_technique:
info_lines.append(f"工艺: {carving_technique}")
if design_style:
info_lines.append(f"风格: {design_style}")
if motif:
info_lines.append(f"题材: {motif}")
if size_spec:
info_lines.append(f"尺寸: {size_spec}")
if surface_finish:
info_lines.append(f"表面: {surface_finish}")
if usage_scene:
info_lines.append(f"用途: {usage_scene}")
for line in info_lines:
draw.text((width // 2, y_position), line, font=info_font, fill=text_rgb, anchor="mm")
y_position += 50
# 绘制分隔线
y_position += 20
draw.line([(100, y_position), (700, y_position)], fill=text_rgb, width=1)
y_position += 40
# 绘制用户需求标题
draw.text((width // 2, y_position), "设计需求:", font=info_font, fill=text_rgb, anchor="mm")
y_position += 50
# 绘制用户需求文本(自动换行)
max_chars_per_line = 20
prompt_lines = []
current_line = ""
for char in prompt:
current_line += char
if len(current_line) >= max_chars_per_line:
prompt_lines.append(current_line)
current_line = ""
if current_line:
prompt_lines.append(current_line)
# 限制最多显示 5 行
for line in prompt_lines[:5]:
draw.text((width // 2, y_position), line, font=prompt_font, fill=text_rgb, anchor="mm")
y_position += 40
if len(prompt_lines) > 5:
draw.text((width // 2, y_position), "...", font=prompt_font, fill=text_rgb, anchor="mm")
# 绘制底部装饰
draw.rectangle([(50, 720), (750, 750)], outline=text_rgb, width=2)
draw.text((width // 2, 735), "AI Generated Mock Design", font=get_font(20), fill=text_rgb, anchor="mm")
# 确保目录存在
os.makedirs(os.path.dirname(save_path), exist_ok=True)
# 保存图片
image.save(save_path, "PNG")
# 返回相对 URL 路径
# save_path 格式类似 uploads/designs/1001.png
# 需要转换为 /uploads/designs/1001.png
relative_path = save_path.replace("\\", "/")
if not relative_path.startswith("/"):
relative_path = "/" + relative_path
return relative_path

View File

@@ -0,0 +1 @@
# 工具函数模块

58
backend/app/utils/deps.py Normal file
View File

@@ -0,0 +1,58 @@
"""
认证依赖注入
提供用户认证相关的 FastAPI 依赖
"""
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from sqlalchemy.orm import Session
from ..config import settings
from ..database import get_db
from ..models.user import User
# OAuth2 密码认证方案tokenUrl 指向登录接口
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login")
def get_current_user(
token: str = Depends(oauth2_scheme),
db: Session = Depends(get_db)
) -> User:
"""
获取当前登录用户
从 JWT token 中解析用户 ID查询数据库返回用户对象
Args:
token: JWT access token
db: 数据库会话
Returns:
当前登录的用户对象
Raises:
HTTPException: token 无效或用户不存在时抛出 401
"""
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="无法验证凭据",
headers={"WWW-Authenticate": "Bearer"},
)
try:
# 解码 JWT token
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
user_id: str = payload.get("sub")
if user_id is None:
raise credentials_exception
except JWTError:
raise credentials_exception
# 从数据库查询用户
user = db.query(User).filter(User.id == int(user_id)).first()
if user is None:
raise credentials_exception
return user

View File

@@ -0,0 +1,63 @@
"""
安全工具函数
包含 JWT 令牌创建和密码加密验证
"""
from datetime import datetime, timedelta
from typing import Optional, Any
from jose import jwt
from passlib.context import CryptContext
from ..config import settings
# 密码加密上下文
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
"""
创建 JWT access token
Args:
data: 要编码到 token 中的数据
expires_delta: token 过期时间,默认使用配置中的时间
Returns:
编码后的 JWT token 字符串
"""
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
return encoded_jwt
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""
验证密码
Args:
plain_password: 明文密码
hashed_password: 哈希后的密码
Returns:
密码是否匹配
"""
return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password: str) -> str:
"""
对密码进行哈希
Args:
password: 明文密码
Returns:
哈希后的密码字符串
"""
return pwd_context.hash(password)