feat(ai): 支持双模型多视角AI设计生图与后台管理系统

- 实现AI多视角设计图生成功能,支持6个可选设计参数配置
- 集成SiliconFlow FLUX.1与火山引擎Seedream 4.5双模型切换
- 构建专业中文转英文prompt系统,提升AI生成质量
- 前端设计预览支持多视角切换与视角指示器展示
- 增加多视角设计图片DesignImage模型关联及存储
- 后端设计服务异步调用AI接口,失败时降级生成mock图
- 新增管理员后台管理路由及完整的权限校验机制
- 实现后台模块:仪表盘、系统配置、用户/品类/设计管理
- 配置数据库系统配置表,支持动态AI配置及热更新
- 增加用户管理员标识字段,管理后台登录鉴权支持
- 更新API接口支持多视角设计参数及后台管理接口
- 优化设计删除逻辑,删除多视角相关图片文件
- 前端新增管理后台页面与路由,布局样式独立分离
- 更新环境变量增加AI模型相关Key与参数配置说明
- 引入httpx异步HTTP客户端用于AI接口调用及图片下载
- README文档完善AI多视角生图与后台管理详细功能与流程说明
This commit is contained in:
2026-03-27 15:29:50 +08:00
parent e3ff55b4db
commit 032c43525a
41 changed files with 3756 additions and 81 deletions

View File

@@ -13,6 +13,13 @@ class Settings(BaseSettings):
ALGORITHM: str = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES: int = 1440
UPLOAD_DIR: str = "uploads"
# AI 生图配置
SILICONFLOW_API_KEY: str = ""
SILICONFLOW_BASE_URL: str = "https://api.siliconflow.cn/v1"
VOLCENGINE_API_KEY: str = ""
VOLCENGINE_BASE_URL: str = "https://ark.cn-beijing.volces.com/api/v3"
AI_IMAGE_MODEL: str = "flux-dev" # flux-dev 或 seedream-4.5
AI_IMAGE_SIZE: int = 1024
class Config:
env_file = ".env"

View File

@@ -10,6 +10,7 @@ from fastapi.staticfiles import StaticFiles
from .config import settings
from .routers import categories, designs, users
from .routers import auth
from .routers import admin
@asynccontextmanager
@@ -62,6 +63,7 @@ app.include_router(auth.router)
app.include_router(categories.router)
app.include_router(designs.router)
app.include_router(users.router)
app.include_router(admin.router)
# 配置静态文件服务
app.mount("/uploads", StaticFiles(directory="uploads"), name="uploads")

View File

@@ -6,6 +6,9 @@ from ..database import Base
from .user import User
from .category import Category, SubType, Color
from .design import Design
from .design_image import DesignImage
from .system_config import SystemConfig
from .prompt_template import PromptTemplate, PromptMapping
__all__ = [
"Base",
@@ -13,5 +16,9 @@ __all__ = [
"Category",
"SubType",
"Color",
"Design"
"Design",
"DesignImage",
"SystemConfig",
"PromptTemplate",
"PromptMapping",
]

View File

@@ -34,6 +34,7 @@ class Design(Base):
category = relationship("Category", back_populates="designs")
sub_type = relationship("SubType", back_populates="designs")
color = relationship("Color", back_populates="designs")
images = relationship("DesignImage", back_populates="design", cascade="all, delete-orphan", order_by="DesignImage.sort_order")
def __repr__(self):
return f"<Design(id={self.id}, status='{self.status}')>"

View File

@@ -0,0 +1,29 @@
"""
设计图片模型
存储每个设计的多视角图片
"""
from sqlalchemy import Column, BigInteger, Integer, String, Text, DateTime, ForeignKey
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship
from ..database import Base
class DesignImage(Base):
"""设计图片表 - 存储多视角设计图"""
__tablename__ = "design_images"
id = Column(BigInteger, primary_key=True, autoincrement=True, comment="图片ID")
design_id = Column(BigInteger, ForeignKey("designs.id", ondelete="CASCADE"), nullable=False, comment="关联设计ID")
view_name = Column(String(20), nullable=False, comment="视角名称: 效果图/正面图/侧面图/背面图")
image_url = Column(String(255), nullable=True, comment="图片URL路径")
model_used = Column(String(50), nullable=True, comment="使用的AI模型: flux-dev/seedream-4.5")
prompt_used = Column(Text, nullable=True, comment="实际使用的英文prompt")
sort_order = Column(Integer, default=0, comment="排序")
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
# 关联关系
design = relationship("Design", back_populates="images")
def __repr__(self):
return f"<DesignImage(id={self.id}, view='{self.view_name}')>"

View File

@@ -0,0 +1,37 @@
"""
提示词模板模型
存储可后台配置的提示词模板和映射数据
"""
from sqlalchemy import Column, BigInteger, Integer, String, Text, DateTime
from sqlalchemy.sql import func
from ..database import Base
class PromptTemplate(Base):
"""提示词模板表"""
__tablename__ = "prompt_templates"
id = Column(BigInteger, primary_key=True, autoincrement=True, comment="模板ID")
template_key = Column(String(100), unique=True, nullable=False, comment="模板键: main_template / quality_suffix")
template_value = Column(Text, nullable=False, comment="模板内容,支持{变量}占位符")
description = Column(String(255), nullable=True, comment="模板说明")
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
def __repr__(self):
return f"<PromptTemplate(key='{self.template_key}')>"
class PromptMapping(Base):
"""提示词映射表 - 中文参数到英文描述的映射"""
__tablename__ = "prompt_mappings"
id = Column(BigInteger, primary_key=True, autoincrement=True, comment="映射ID")
mapping_type = Column(String(50), nullable=False, comment="映射类型: category/color/carving/style/motif/finish/scene/view/sub_type")
cn_key = Column(String(100), nullable=False, comment="中文键")
en_value = Column(Text, nullable=False, comment="英文描述")
sort_order = Column(Integer, default=0, comment="排序")
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
def __repr__(self):
return f"<PromptMapping(type='{self.mapping_type}', key='{self.cn_key}')>"

View File

@@ -0,0 +1,24 @@
"""
系统配置模型
存储可通过后台管理的系统配置项
"""
from sqlalchemy import Column, BigInteger, String, Text, DateTime
from sqlalchemy.sql import func
from ..database import Base
class SystemConfig(Base):
"""系统配置表"""
__tablename__ = "system_configs"
id = Column(BigInteger, primary_key=True, autoincrement=True, comment="配置ID")
config_key = Column(String(100), unique=True, nullable=False, comment="配置键")
config_value = Column(Text, nullable=True, comment="配置值")
description = Column(String(255), nullable=True, comment="配置说明")
config_group = Column(String(50), nullable=False, default="general", comment="配置分组: ai/general")
is_secret = Column(String(1), nullable=False, default="N", comment="是否敏感信息(Y/N)")
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
def __repr__(self):
return f"<SystemConfig(key='{self.config_key}')>"

View File

@@ -1,7 +1,7 @@
"""
用户模型
"""
from sqlalchemy import Column, BigInteger, String, DateTime
from sqlalchemy import Column, BigInteger, String, DateTime, Boolean
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship
@@ -18,6 +18,7 @@ class User(Base):
hashed_password = Column(String(255), nullable=False, comment="加密密码")
nickname = Column(String(50), nullable=True, comment="昵称")
avatar = Column(String(255), nullable=True, comment="头像URL")
is_admin = Column(Boolean, default=False, nullable=False, comment="是否管理员")
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")

View File

@@ -0,0 +1,627 @@
"""
管理后台路由
提供系统配置、用户管理、品类管理、设计管理接口
所有接口需要管理员权限
"""
from datetime import datetime, timedelta
import httpx
from fastapi import APIRouter, Depends, HTTPException, status, Query
from pydantic import BaseModel
from sqlalchemy.orm import Session
from sqlalchemy import func, or_
from ..database import get_db
from ..models import User, Design, Category, SubType, Color, SystemConfig, PromptTemplate, PromptMapping
from ..schemas.admin import (
SystemConfigItem, SystemConfigUpdate, SystemConfigResponse,
AdminUserResponse, AdminUserListResponse, AdminSetAdmin,
CategoryCreate, CategoryUpdate, SubTypeCreate, SubTypeUpdate,
ColorCreate, ColorUpdate,
AdminDesignListResponse, DashboardStats,
PromptTemplateItem, PromptTemplateUpdate,
PromptMappingItem, PromptMappingCreate, PromptMappingUpdate
)
from ..utils.deps import get_admin_user
router = APIRouter(prefix="/api/admin", tags=["管理后台"])
# ==================== 仪表盘 ====================
@router.get("/dashboard", response_model=DashboardStats)
def get_dashboard(
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user)
):
"""获取仪表盘统计数据"""
today = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
return DashboardStats(
total_users=db.query(func.count(User.id)).scalar() or 0,
total_designs=db.query(func.count(Design.id)).scalar() or 0,
total_categories=db.query(func.count(Category.id)).scalar() or 0,
today_designs=db.query(func.count(Design.id)).filter(Design.created_at >= today).scalar() or 0,
today_users=db.query(func.count(User.id)).filter(User.created_at >= today).scalar() or 0,
)
# ==================== 系统配置管理 ====================
@router.get("/configs", response_model=SystemConfigResponse)
def get_configs(
group: str = Query(None, description="按分组筛选"),
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user)
):
"""获取系统配置列表"""
query = db.query(SystemConfig)
if group:
query = query.filter(SystemConfig.config_group == group)
items = query.order_by(SystemConfig.config_group, SystemConfig.config_key).all()
# 敏感信息脱敏显示
result = []
for item in items:
cfg = SystemConfigItem.model_validate(item)
if item.is_secret == "Y" and item.config_value:
# 只显示前4位和后4位
val = item.config_value
if len(val) > 8:
cfg.config_value = val[:4] + "****" + val[-4:]
else:
cfg.config_value = "****"
result.append(cfg)
return SystemConfigResponse(items=result)
@router.put("/configs")
def update_configs(
data: SystemConfigUpdate,
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user)
):
"""批量更新系统配置"""
updated = 0
for key, value in data.configs.items():
config = db.query(SystemConfig).filter(SystemConfig.config_key == key).first()
if config:
config.config_value = value
updated += 1
else:
# 自动创建不存在的配置项
new_config = SystemConfig(
config_key=key,
config_value=value,
config_group="general"
)
db.add(new_config)
updated += 1
db.commit()
return {"message": f"已更新 {updated} 项配置"}
@router.post("/configs/init")
def init_default_configs(
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user)
):
"""初始化默认配置项(仅当配置表为空时)"""
count = db.query(func.count(SystemConfig.id)).scalar()
if count > 0:
return {"message": "配置已存在,跳过初始化"}
defaults = [
("SILICONFLOW_API_KEY", "", "SiliconFlow API Key", "ai", "Y"),
("SILICONFLOW_BASE_URL", "https://api.siliconflow.cn/v1", "SiliconFlow 接口地址", "ai", "N"),
("VOLCENGINE_API_KEY", "", "火山引擎 API Key", "ai", "Y"),
("VOLCENGINE_BASE_URL", "https://ark.cn-beijing.volces.com/api/v3", "火山引擎接口地址", "ai", "N"),
("AI_IMAGE_MODEL", "flux-dev", "默认AI生图模型 (flux-dev / seedream-4.5)", "ai", "N"),
("AI_IMAGE_SIZE", "1024", "AI生图默认尺寸", "ai", "N"),
]
for key, val, desc, group, secret in defaults:
db.add(SystemConfig(
config_key=key, config_value=val,
description=desc, config_group=group, is_secret=secret
))
db.commit()
return {"message": f"已初始化 {len(defaults)} 项默认配置"}
class TestConnectionRequest(BaseModel):
"""API 连接测试请求"""
provider: str # siliconflow / volcengine
api_key: str
base_url: str
@router.post("/configs/test")
async def test_api_connection(
data: TestConnectionRequest,
admin: User = Depends(get_admin_user)
):
"""测试 AI API 连接是否正常"""
try:
if data.provider == "siliconflow":
url = f"{data.base_url}/models"
headers = {"Authorization": f"Bearer {data.api_key}"}
async with httpx.AsyncClient(timeout=15) as client:
resp = await client.get(url, headers=headers)
if resp.status_code == 200:
return {"message": "连接成功API Key 有效"}
elif resp.status_code == 401:
raise HTTPException(status_code=400, detail="API Key 无效,请检查")
else:
raise HTTPException(status_code=400, detail=f"请求失败,状态码: {resp.status_code}")
elif data.provider == "volcengine":
url = f"{data.base_url}/models"
headers = {"Authorization": f"Bearer {data.api_key}"}
async with httpx.AsyncClient(timeout=15) as client:
resp = await client.get(url, headers=headers)
if resp.status_code == 200:
return {"message": "连接成功API Key 有效"}
elif resp.status_code == 401:
raise HTTPException(status_code=400, detail="API Key 无效,请检查")
else:
# 火山引擎可能返回其他状态码但连接本身成功
return {"message": f"连接成功(状态码: {resp.status_code}"}
else:
raise HTTPException(status_code=400, detail=f"未知的服务提供商: {data.provider}")
except httpx.ConnectError:
raise HTTPException(status_code=400, detail="连接失败,请检查接口地址")
except httpx.TimeoutException:
raise HTTPException(status_code=400, detail="连接超时,请检查网络")
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=400, detail=f"测试失败: {str(e)}")
# ==================== 用户管理 ====================
@router.get("/users", response_model=AdminUserListResponse)
def get_users(
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100),
keyword: str = Query(None, description="搜索用户名/昵称"),
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user)
):
"""获取用户列表"""
query = db.query(User)
if keyword:
query = query.filter(
or_(User.username.like(f"%{keyword}%"), User.nickname.like(f"%{keyword}%"))
)
total = query.count()
users = query.order_by(User.created_at.desc()).offset((page - 1) * page_size).limit(page_size).all()
items = []
for u in users:
design_count = db.query(func.count(Design.id)).filter(Design.user_id == u.id).scalar() or 0
items.append(AdminUserResponse(
id=u.id, username=u.username, nickname=u.nickname,
phone=u.phone, is_admin=u.is_admin,
created_at=u.created_at, design_count=design_count
))
return AdminUserListResponse(items=items, total=total, page=page, page_size=page_size)
@router.put("/users/{user_id}/admin")
def set_user_admin(
user_id: int,
data: AdminSetAdmin,
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user)
):
"""设置/取消管理员"""
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise HTTPException(status_code=404, detail="用户不存在")
if user.id == admin.id:
raise HTTPException(status_code=400, detail="不能修改自己的管理员状态")
user.is_admin = data.is_admin
db.commit()
return {"message": f"用户 {user.username} {'已设为管理员' if data.is_admin else '已取消管理员'}"}
@router.delete("/users/{user_id}")
def delete_user(
user_id: int,
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user)
):
"""删除用户"""
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise HTTPException(status_code=404, detail="用户不存在")
if user.id == admin.id:
raise HTTPException(status_code=400, detail="不能删除自己")
if user.is_admin:
raise HTTPException(status_code=400, detail="不能删除其他管理员")
db.delete(user)
db.commit()
return {"message": "用户已删除"}
# ==================== 品类管理 ====================
@router.get("/categories")
def get_categories(
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user)
):
"""获取所有品类(含子类型和颜色)"""
categories = db.query(Category).order_by(Category.sort_order).all()
result = []
for cat in categories:
result.append({
"id": cat.id,
"name": cat.name,
"icon": cat.icon,
"sort_order": cat.sort_order,
"flow_type": cat.flow_type,
"sub_types": [{"id": st.id, "name": st.name, "description": st.description,
"preview_image": st.preview_image, "sort_order": st.sort_order}
for st in sorted(cat.sub_types, key=lambda x: x.sort_order)],
"colors": [{"id": c.id, "name": c.name, "hex_code": c.hex_code, "sort_order": c.sort_order}
for c in sorted(cat.colors, key=lambda x: x.sort_order)]
})
return result
@router.post("/categories")
def create_category(
data: CategoryCreate,
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user)
):
"""创建品类"""
cat = Category(name=data.name, icon=data.icon, sort_order=data.sort_order, flow_type=data.flow_type)
db.add(cat)
db.commit()
db.refresh(cat)
return {"id": cat.id, "message": "品类创建成功"}
@router.put("/categories/{cat_id}")
def update_category(
cat_id: int,
data: CategoryUpdate,
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user)
):
"""更新品类"""
cat = db.query(Category).filter(Category.id == cat_id).first()
if not cat:
raise HTTPException(status_code=404, detail="品类不存在")
for field, value in data.model_dump(exclude_unset=True).items():
setattr(cat, field, value)
db.commit()
return {"message": "品类更新成功"}
@router.delete("/categories/{cat_id}")
def delete_category(
cat_id: int,
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user)
):
"""删除品类(级联删除子类型和颜色)"""
cat = db.query(Category).filter(Category.id == cat_id).first()
if not cat:
raise HTTPException(status_code=404, detail="品类不存在")
# 检查是否有关联设计
design_count = db.query(func.count(Design.id)).filter(Design.category_id == cat_id).scalar()
if design_count > 0:
raise HTTPException(status_code=400, detail=f"品类下有 {design_count} 个设计,无法删除")
# 删除子类型和颜色
db.query(SubType).filter(SubType.category_id == cat_id).delete()
db.query(Color).filter(Color.category_id == cat_id).delete()
db.delete(cat)
db.commit()
return {"message": "品类已删除"}
# -- 子类型 --
@router.post("/sub-types")
def create_sub_type(
data: SubTypeCreate,
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user)
):
"""创建子类型"""
cat = db.query(Category).filter(Category.id == data.category_id).first()
if not cat:
raise HTTPException(status_code=404, detail="品类不存在")
st = SubType(category_id=data.category_id, name=data.name,
description=data.description, preview_image=data.preview_image,
sort_order=data.sort_order)
db.add(st)
db.commit()
db.refresh(st)
return {"id": st.id, "message": "子类型创建成功"}
@router.put("/sub-types/{st_id}")
def update_sub_type(
st_id: int,
data: SubTypeUpdate,
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user)
):
"""更新子类型"""
st = db.query(SubType).filter(SubType.id == st_id).first()
if not st:
raise HTTPException(status_code=404, detail="子类型不存在")
for field, value in data.model_dump(exclude_unset=True).items():
setattr(st, field, value)
db.commit()
return {"message": "子类型更新成功"}
@router.delete("/sub-types/{st_id}")
def delete_sub_type(
st_id: int,
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user)
):
"""删除子类型"""
st = db.query(SubType).filter(SubType.id == st_id).first()
if not st:
raise HTTPException(status_code=404, detail="子类型不存在")
db.delete(st)
db.commit()
return {"message": "子类型已删除"}
# -- 颜色 --
@router.post("/colors")
def create_color(
data: ColorCreate,
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user)
):
"""创建颜色"""
cat = db.query(Category).filter(Category.id == data.category_id).first()
if not cat:
raise HTTPException(status_code=404, detail="品类不存在")
color = Color(category_id=data.category_id, name=data.name,
hex_code=data.hex_code, sort_order=data.sort_order)
db.add(color)
db.commit()
db.refresh(color)
return {"id": color.id, "message": "颜色创建成功"}
@router.put("/colors/{color_id}")
def update_color(
color_id: int,
data: ColorUpdate,
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user)
):
"""更新颜色"""
color = db.query(Color).filter(Color.id == color_id).first()
if not color:
raise HTTPException(status_code=404, detail="颜色不存在")
for field, value in data.model_dump(exclude_unset=True).items():
setattr(color, field, value)
db.commit()
return {"message": "颜色更新成功"}
@router.delete("/colors/{color_id}")
def delete_color(
color_id: int,
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user)
):
"""删除颜色"""
color = db.query(Color).filter(Color.id == color_id).first()
if not color:
raise HTTPException(status_code=404, detail="颜色不存在")
db.delete(color)
db.commit()
return {"message": "颜色已删除"}
# ==================== 设计管理 ====================
@router.get("/designs", response_model=AdminDesignListResponse)
def get_all_designs(
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100),
user_id: int = Query(None, description="按用户筛选"),
status_filter: str = Query(None, alias="status", description="按状态筛选"),
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user)
):
"""获取所有设计列表"""
query = db.query(Design)
if user_id:
query = query.filter(Design.user_id == user_id)
if status_filter:
query = query.filter(Design.status == status_filter)
total = query.count()
designs = query.order_by(Design.created_at.desc()).offset((page - 1) * page_size).limit(page_size).all()
items = []
for d in designs:
items.append({
"id": d.id,
"user_id": d.user_id,
"username": d.user.username if d.user else None,
"category_name": d.category.name if d.category else None,
"sub_type_name": d.sub_type.name if d.sub_type else None,
"color_name": d.color.name if d.color else None,
"prompt": d.prompt,
"image_url": d.image_url,
"status": d.status,
"created_at": d.created_at.isoformat() if d.created_at else None,
})
return AdminDesignListResponse(items=items, total=total, page=page, page_size=page_size)
@router.delete("/designs/{design_id}")
def admin_delete_design(
design_id: int,
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user)
):
"""管理员删除任意设计"""
design = db.query(Design).filter(Design.id == design_id).first()
if not design:
raise HTTPException(status_code=404, detail="设计不存在")
db.delete(design)
db.commit()
return {"message": "设计已删除"}
# ==================== 提示词管理 ====================
@router.get("/prompt-templates")
def get_prompt_templates(
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user)
):
"""获取所有提示词模板"""
templates = db.query(PromptTemplate).order_by(PromptTemplate.template_key).all()
return [PromptTemplateItem.model_validate(t) for t in templates]
@router.put("/prompt-templates/{template_id}")
def update_prompt_template(
template_id: int,
data: PromptTemplateUpdate,
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user)
):
"""更新提示词模板"""
tpl = db.query(PromptTemplate).filter(PromptTemplate.id == template_id).first()
if not tpl:
raise HTTPException(status_code=404, detail="模板不存在")
tpl.template_value = data.template_value
if data.description is not None:
tpl.description = data.description
db.commit()
return {"message": f"模板 '{tpl.template_key}' 更新成功"}
@router.get("/prompt-mappings")
def get_prompt_mappings(
mapping_type: str = Query(None, description="按映射类型筛选"),
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user)
):
"""获取提示词映射列表"""
query = db.query(PromptMapping)
if mapping_type:
query = query.filter(PromptMapping.mapping_type == mapping_type)
mappings = query.order_by(PromptMapping.mapping_type, PromptMapping.sort_order).all()
return [PromptMappingItem.model_validate(m) for m in mappings]
@router.get("/prompt-mappings/types")
def get_mapping_types(
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user)
):
"""获取所有映射类型及其数量"""
from sqlalchemy import distinct
types = db.query(
PromptMapping.mapping_type,
func.count(PromptMapping.id)
).group_by(PromptMapping.mapping_type).all()
return [{"type": t, "count": c, "label": {
"category": "品类", "color": "颜色", "view": "视角",
"carving": "雕刻工艺", "style": "设计风格", "motif": "题材纹样",
"finish": "表面处理", "scene": "用途场景", "sub_type": "子类型"
}.get(t, t)} for t, c in types]
@router.post("/prompt-mappings")
def create_prompt_mapping(
data: PromptMappingCreate,
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user)
):
"""创建提示词映射"""
# 检查重复
existing = db.query(PromptMapping).filter(
PromptMapping.mapping_type == data.mapping_type,
PromptMapping.cn_key == data.cn_key
).first()
if existing:
raise HTTPException(status_code=400, detail=f"映射 '{data.cn_key}' 已存在")
mapping = PromptMapping(
mapping_type=data.mapping_type,
cn_key=data.cn_key,
en_value=data.en_value,
sort_order=data.sort_order
)
db.add(mapping)
db.commit()
db.refresh(mapping)
return {"id": mapping.id, "message": "映射创建成功"}
@router.put("/prompt-mappings/{mapping_id}")
def update_prompt_mapping(
mapping_id: int,
data: PromptMappingUpdate,
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user)
):
"""更新提示词映射"""
mapping = db.query(PromptMapping).filter(PromptMapping.id == mapping_id).first()
if not mapping:
raise HTTPException(status_code=404, detail="映射不存在")
for field, value in data.model_dump(exclude_unset=True).items():
setattr(mapping, field, value)
db.commit()
return {"message": "映射更新成功"}
@router.delete("/prompt-mappings/{mapping_id}")
def delete_prompt_mapping(
mapping_id: int,
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user)
):
"""删除提示词映射"""
mapping = db.query(PromptMapping).filter(PromptMapping.id == mapping_id).first()
if not mapping:
raise HTTPException(status_code=404, detail="映射不存在")
db.delete(mapping)
db.commit()
return {"message": "映射已删除"}
@router.post("/prompt-preview")
def preview_prompt(
params: dict,
db: Session = Depends(get_db),
admin: User = Depends(get_admin_user)
):
"""预览提示词生成结果"""
from ..services.prompt_builder import build_prompt
try:
prompt = build_prompt(
category_name=params.get("category_name", "牌子"),
view_name=params.get("view_name", "效果图"),
sub_type_name=params.get("sub_type_name"),
color_name=params.get("color_name"),
user_prompt=params.get("user_prompt"),
carving_technique=params.get("carving_technique"),
design_style=params.get("design_style"),
motif=params.get("motif"),
size_spec=params.get("size_spec"),
surface_finish=params.get("surface_finish"),
usage_scene=params.get("usage_scene"),
)
return {"prompt": prompt}
except Exception as e:
raise HTTPException(status_code=400, detail=f"提示词生成失败: {str(e)}")

View File

@@ -9,7 +9,7 @@ from sqlalchemy.orm import Session
from ..database import get_db
from ..models import User, Design
from ..schemas import DesignCreate, DesignResponse, DesignListResponse
from ..schemas import DesignCreate, DesignResponse, DesignListResponse, DesignImageResponse
from ..utils.deps import get_current_user
from ..services import design_service
@@ -18,6 +18,21 @@ router = APIRouter(prefix="/api/designs", tags=["设计"])
def design_to_response(design: Design) -> DesignResponse:
"""将 Design 模型转换为响应格式"""
# 构建多视角图片列表
images = []
if hasattr(design, 'images') and design.images:
images = [
DesignImageResponse(
id=img.id,
view_name=img.view_name,
image_url=img.image_url,
model_used=img.model_used,
prompt_used=img.prompt_used,
sort_order=img.sort_order,
)
for img in design.images
]
return DesignResponse(
id=design.id,
user_id=design.user_id,
@@ -51,6 +66,7 @@ def design_to_response(design: Design) -> DesignResponse:
surface_finish=design.surface_finish,
usage_scene=design.usage_scene,
image_url=design.image_url,
images=images,
status=design.status,
created_at=design.created_at,
updated_at=design.updated_at
@@ -58,17 +74,17 @@ def design_to_response(design: Design) -> DesignResponse:
@router.post("/generate", response_model=DesignResponse)
def generate_design(
async def generate_design(
design_data: DesignCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
提交设计生成请求
提交设计生成请求(异步,支持 AI 多视角生图)
需要认证
"""
try:
design = design_service.create_design(
design = await design_service.create_design_async(
db=db,
user_id=current_user.id,
design_data=design_data

View File

@@ -4,7 +4,7 @@ Pydantic Schemas
"""
from .user import UserCreate, UserLogin, UserResponse, Token, UserUpdate, PasswordChange
from .category import CategoryResponse, SubTypeResponse, ColorResponse
from .design import DesignCreate, DesignResponse, DesignListResponse
from .design import DesignCreate, DesignResponse, DesignListResponse, DesignImageResponse
__all__ = [
# User schemas
@@ -22,4 +22,5 @@ __all__ = [
"DesignCreate",
"DesignResponse",
"DesignListResponse",
"DesignImageResponse",
]

View File

@@ -0,0 +1,173 @@
"""
管理后台相关 Pydantic Schemas
"""
from pydantic import BaseModel, Field
from datetime import datetime
from typing import Optional, List
# ============ 系统配置 ============
class SystemConfigItem(BaseModel):
"""单个配置项"""
config_key: str
config_value: Optional[str] = None
description: Optional[str] = None
config_group: str = "general"
is_secret: str = "N"
updated_at: Optional[datetime] = None
class Config:
from_attributes = True
class SystemConfigUpdate(BaseModel):
"""批量更新配置"""
configs: dict = Field(..., description="键值对: {config_key: config_value}")
class SystemConfigResponse(BaseModel):
"""配置列表响应"""
items: List[SystemConfigItem]
# ============ 用户管理 ============
class AdminUserResponse(BaseModel):
"""管理端用户信息"""
id: int
username: str
nickname: Optional[str] = None
phone: Optional[str] = None
is_admin: bool = False
created_at: datetime
design_count: int = 0
class Config:
from_attributes = True
class AdminUserListResponse(BaseModel):
"""用户列表响应"""
items: List[AdminUserResponse]
total: int
page: int
page_size: int
class AdminSetAdmin(BaseModel):
"""设置管理员"""
is_admin: bool
# ============ 品类管理 ============
class CategoryCreate(BaseModel):
"""创建品类"""
name: str = Field(..., max_length=50)
icon: Optional[str] = Field(None, max_length=255)
sort_order: int = 0
flow_type: str = Field("full", max_length=20)
class CategoryUpdate(BaseModel):
"""更新品类"""
name: Optional[str] = Field(None, max_length=50)
icon: Optional[str] = Field(None, max_length=255)
sort_order: Optional[int] = None
flow_type: Optional[str] = Field(None, max_length=20)
class SubTypeCreate(BaseModel):
"""创建子类型"""
category_id: int
name: str = Field(..., max_length=50)
description: Optional[str] = Field(None, max_length=255)
preview_image: Optional[str] = Field(None, max_length=255)
sort_order: int = 0
class SubTypeUpdate(BaseModel):
"""更新子类型"""
name: Optional[str] = Field(None, max_length=50)
description: Optional[str] = Field(None, max_length=255)
preview_image: Optional[str] = Field(None, max_length=255)
sort_order: Optional[int] = None
class ColorCreate(BaseModel):
"""创建颜色"""
category_id: int
name: str = Field(..., max_length=50)
hex_code: Optional[str] = Field(None, max_length=10)
sort_order: int = 0
class ColorUpdate(BaseModel):
"""更新颜色"""
name: Optional[str] = Field(None, max_length=50)
hex_code: Optional[str] = Field(None, max_length=10)
sort_order: Optional[int] = None
# ============ 设计管理 ============
class AdminDesignListResponse(BaseModel):
"""管理端设计列表"""
items: list
total: int
page: int
page_size: int
# ============ 提示词管理 ============
class PromptTemplateItem(BaseModel):
"""提示词模板"""
id: Optional[int] = None
template_key: str
template_value: str
description: Optional[str] = None
updated_at: Optional[datetime] = None
class Config:
from_attributes = True
class PromptTemplateUpdate(BaseModel):
"""更新提示词模板"""
template_value: str
description: Optional[str] = None
class PromptMappingItem(BaseModel):
"""提示词映射"""
id: Optional[int] = None
mapping_type: str
cn_key: str
en_value: str
sort_order: int = 0
updated_at: Optional[datetime] = None
class Config:
from_attributes = True
class PromptMappingCreate(BaseModel):
"""创建提示词映射"""
mapping_type: str
cn_key: str
en_value: str
sort_order: int = 0
class PromptMappingUpdate(BaseModel):
"""更新提示词映射"""
cn_key: Optional[str] = None
en_value: Optional[str] = None
sort_order: Optional[int] = None
# ============ 仪表盘 ============
class DashboardStats(BaseModel):
"""仪表盘统计"""
total_users: int
total_designs: int
total_categories: int
today_designs: int
today_users: int

View File

@@ -8,6 +8,20 @@ from typing import Optional, List
from .category import CategoryResponse, SubTypeResponse, ColorResponse
class DesignImageResponse(BaseModel):
"""设计图片响应(单张视角图)"""
id: int
view_name: str
image_url: Optional[str] = None
model_used: Optional[str] = None
prompt_used: Optional[str] = None
sort_order: int = 0
class Config:
from_attributes = True
protected_namespaces = ()
class DesignCreate(BaseModel):
"""创建设计请求"""
category_id: int = Field(..., description="品类ID")
@@ -37,6 +51,7 @@ class DesignResponse(BaseModel):
surface_finish: Optional[str] = None
usage_scene: Optional[str] = None
image_url: Optional[str] = None
images: List[DesignImageResponse] = []
status: str
created_at: datetime
updated_at: datetime

View File

@@ -26,6 +26,7 @@ class UserResponse(BaseModel):
nickname: Optional[str] = None
phone: Optional[str] = None
avatar: Optional[str] = None
is_admin: bool = False
created_at: datetime
class Config:

View File

@@ -0,0 +1,144 @@
"""
AI 生图服务
支持双模型SiliconFlow FLUX.1 [dev] 和 火山引擎 Seedream 4.5
"""
import os
import uuid
import logging
import httpx
from typing import Optional
from ..config import settings
from .config_service import get_ai_config
logger = logging.getLogger(__name__)
# 超时设置(秒)
REQUEST_TIMEOUT = 90
# 最大重试次数
MAX_RETRIES = 3
async def _call_siliconflow(prompt: str, size: int = 1024, ai_config: dict = None) -> str:
"""
调用 SiliconFlow FLUX.1 [dev] 生图 API
Returns:
远程图片 URL
"""
cfg = ai_config or get_ai_config()
url = f"{cfg['SILICONFLOW_BASE_URL']}/images/generations"
headers = {
"Authorization": f"Bearer {cfg['SILICONFLOW_API_KEY']}",
"Content-Type": "application/json",
}
payload = {
"model": "black-forest-labs/FLUX.1-dev",
"prompt": prompt,
"image_size": f"{size}x{size}",
"num_inference_steps": 20,
}
async with httpx.AsyncClient(timeout=REQUEST_TIMEOUT) as client:
resp = await client.post(url, json=payload, headers=headers)
resp.raise_for_status()
data = resp.json()
# SiliconFlow 响应格式: {"images": [{"url": "https://..."}]}
images = data.get("images", [])
if not images:
raise ValueError("SiliconFlow 返回空图片列表")
return images[0]["url"]
async def _call_seedream(prompt: str, size: int = 1024, ai_config: dict = None) -> str:
"""
调用火山引擎 Seedream 4.5 生图 API
Returns:
远程图片 URL
"""
cfg = ai_config or get_ai_config()
url = f"{cfg['VOLCENGINE_BASE_URL']}/images/generations"
headers = {
"Authorization": f"Bearer {cfg['VOLCENGINE_API_KEY']}",
"Content-Type": "application/json",
}
payload = {
"model": "doubao-seedream-4.5-t2i-250528",
"prompt": prompt,
"size": f"{size}x{size}",
"response_format": "url",
}
async with httpx.AsyncClient(timeout=REQUEST_TIMEOUT) as client:
resp = await client.post(url, json=payload, headers=headers)
resp.raise_for_status()
data = resp.json()
# Seedream 响应格式: {"data": [{"url": "https://..."}]}
items = data.get("data", [])
if not items:
raise ValueError("Seedream 返回空图片列表")
return items[0]["url"]
async def generate_image(prompt: str, model: Optional[str] = None) -> str:
"""
统一生图接口,带重试机制
Args:
prompt: 英文提示词
model: 模型名称 (flux-dev / seedream-4.5),为空则使用配置默认值
Returns:
远程图片 URL
Raises:
Exception: 所有重试失败后抛出
"""
ai_config = get_ai_config()
model = model or ai_config.get("AI_IMAGE_MODEL", "flux-dev")
size = ai_config.get("AI_IMAGE_SIZE", 1024)
last_error: Optional[Exception] = None
for attempt in range(1, MAX_RETRIES + 1):
try:
if model == "seedream-4.5":
image_url = await _call_seedream(prompt, size, ai_config)
else:
image_url = await _call_siliconflow(prompt, size, ai_config)
logger.info(f"AI 生图成功 (model={model}, attempt={attempt})")
return image_url
except Exception as e:
last_error = e
logger.warning(f"AI 生图失败 (model={model}, attempt={attempt}/{MAX_RETRIES}): {e}")
if attempt < MAX_RETRIES:
import asyncio
await asyncio.sleep(2 * attempt) # 指数退避
raise RuntimeError(f"AI 生图在 {MAX_RETRIES} 次重试后仍然失败: {last_error}")
async def download_and_save(image_url: str, save_path: str) -> str:
"""
下载远程图片并保存到本地
Args:
image_url: 远程图片 URL
save_path: 本地保存路径(如 uploads/designs/1001_效果图.png
Returns:
本地文件相对路径(以 / 开头,如 /uploads/designs/1001_效果图.png
"""
# 确保目录存在
os.makedirs(os.path.dirname(save_path), exist_ok=True)
async with httpx.AsyncClient(timeout=60, follow_redirects=True) as client:
resp = await client.get(image_url)
resp.raise_for_status()
with open(save_path, "wb") as f:
f.write(resp.content)
logger.info(f"图片已下载保存: {save_path}")
return f"/{save_path}"

View File

@@ -0,0 +1,58 @@
"""
配置服务
优先从数据库 system_configs 表读取配置,数据库无值时回退到 .env
"""
import logging
from typing import Optional
from sqlalchemy.orm import Session
from ..database import SessionLocal
from ..models.system_config import SystemConfig
from ..config import settings as env_settings
logger = logging.getLogger(__name__)
def get_config_value(key: str, default: Optional[str] = None) -> Optional[str]:
"""
获取配置值(数据库优先,.env 兜底)
Args:
key: 配置键名(如 SILICONFLOW_API_KEY
default: 默认值
Returns:
配置值字符串
"""
# 1. 尝试从数据库读取
try:
db = SessionLocal()
try:
config = db.query(SystemConfig).filter(SystemConfig.config_key == key).first()
if config and config.config_value:
return config.config_value
finally:
db.close()
except Exception as e:
logger.warning(f"从数据库读取配置 {key} 失败: {e}")
# 2. 回退到 .env / Settings
env_value = getattr(env_settings, key, None)
if env_value is not None and env_value != "":
return str(env_value)
return default
def get_ai_config() -> dict:
"""
获取所有 AI 相关配置
返回字典,方便 ai_generator 使用
"""
return {
"SILICONFLOW_API_KEY": get_config_value("SILICONFLOW_API_KEY", ""),
"SILICONFLOW_BASE_URL": get_config_value("SILICONFLOW_BASE_URL", "https://api.siliconflow.cn/v1"),
"VOLCENGINE_API_KEY": get_config_value("VOLCENGINE_API_KEY", ""),
"VOLCENGINE_BASE_URL": get_config_value("VOLCENGINE_BASE_URL", "https://ark.cn-beijing.volces.com/api/v3"),
"AI_IMAGE_MODEL": get_config_value("AI_IMAGE_MODEL", "flux-dev"),
"AI_IMAGE_SIZE": int(get_config_value("AI_IMAGE_SIZE", "1024")),
}

View File

@@ -1,40 +1,56 @@
"""
设计服务
处理设计相关的业务逻辑
处理设计相关的业务逻辑,支持 AI 多视角生图 + mock 降级
"""
import os
import logging
from typing import List, Optional, Tuple
from sqlalchemy.orm import Session
from sqlalchemy import desc
from ..models import Design, Category, SubType, Color
from ..models import Design, DesignImage, Category, SubType, Color
from ..schemas import DesignCreate
from ..config import settings
from .mock_generator import generate_mock_design
from .prompt_builder import get_views_for_category, build_prompt
from . import ai_generator
logger = logging.getLogger(__name__)
def create_design(db: Session, user_id: int, design_data: DesignCreate) -> Design:
def _has_ai_key() -> bool:
"""检查是否配置了 AI API Key"""
model = settings.AI_IMAGE_MODEL
if model == "seedream-4.5":
return bool(settings.VOLCENGINE_API_KEY)
return bool(settings.SILICONFLOW_API_KEY)
async def create_design_async(db: Session, user_id: int, design_data: DesignCreate) -> Design:
"""
创建设计记录
1. 创建设计记录status=generating
2. 调用 mock_generator 生成图片
3. 更新设计记录status=completed, image_url
4. 返回设计对象
创建设计记录(异步版本,支持 AI 多视角生图)
流程:
1. 创建 Design 记录status=generating
2. 获取品类视角列表
3. 循环每个视角:构建 prompt → 调用 AI 生图 → 下载保存 → 创建 DesignImage
4. 第一张效果图 URL 存入 design.image_url兼容旧逻辑
5. 更新 status=completed
6. 失败时降级到 mock_generator
"""
# 获取关联信息
category = db.query(Category).filter(Category.id == design_data.category_id).first()
if not category:
raise ValueError(f"品类不存在: {design_data.category_id}")
sub_type = None
if design_data.sub_type_id:
sub_type = db.query(SubType).filter(SubType.id == design_data.sub_type_id).first()
color = None
if design_data.color_id:
color = db.query(Color).filter(Color.id == design_data.color_id).first()
# 创建设计记录
design = Design(
user_id=user_id,
@@ -52,8 +68,109 @@ def create_design(db: Session, user_id: int, design_data: DesignCreate) -> Desig
)
db.add(design)
db.flush() # 获取 ID
# 生成图片
# 尝试 AI 生图
if _has_ai_key():
try:
await _generate_ai_images(db, design, category, sub_type, color, design_data)
db.commit()
db.refresh(design)
return design
except Exception as e:
logger.error(f"AI 生图全部失败,降级到 mock: {e}")
db.rollback()
# 重新查询,因为 rollback 后 ORM 对象可能失效
design = db.query(Design).filter(Design.id == design.id).first()
if not design:
# rollback 导致 design 也没了,重新创建
design = Design(
user_id=user_id,
category_id=design_data.category_id,
sub_type_id=design_data.sub_type_id,
color_id=design_data.color_id,
prompt=design_data.prompt,
carving_technique=design_data.carving_technique,
design_style=design_data.design_style,
motif=design_data.motif,
size_spec=design_data.size_spec,
surface_finish=design_data.surface_finish,
usage_scene=design_data.usage_scene,
status="generating"
)
db.add(design)
db.flush()
# 降级到 mock 生成
_generate_mock_fallback(db, design, category, sub_type, color, design_data)
db.commit()
db.refresh(design)
return design
async def _generate_ai_images(
db: Session,
design: Design,
category,
sub_type,
color,
design_data: DesignCreate,
) -> None:
"""使用 AI 模型为每个视角生成图片"""
views = get_views_for_category(category.name)
model = settings.AI_IMAGE_MODEL
for idx, view_name in enumerate(views):
# 构建 prompt
prompt_text = build_prompt(
category_name=category.name,
view_name=view_name,
sub_type_name=sub_type.name if sub_type else None,
color_name=color.name if color else None,
user_prompt=design_data.prompt,
carving_technique=design_data.carving_technique,
design_style=design_data.design_style,
motif=design_data.motif,
size_spec=design_data.size_spec,
surface_finish=design_data.surface_finish,
usage_scene=design_data.usage_scene,
)
# 调用 AI 生图
remote_url = await ai_generator.generate_image(prompt_text, model)
# 下载保存到本地
save_path = os.path.join(
settings.UPLOAD_DIR, "designs", f"{design.id}_{view_name}.png"
)
local_url = await ai_generator.download_and_save(remote_url, save_path)
# 创建 DesignImage 记录
design_image = DesignImage(
design_id=design.id,
view_name=view_name,
image_url=local_url,
model_used=model,
prompt_used=prompt_text,
sort_order=idx,
)
db.add(design_image)
# 第一张图(效果图)存入 design.image_url 兼容旧逻辑
if idx == 0:
design.image_url = local_url
design.status = "completed"
def _generate_mock_fallback(
db: Session,
design: Design,
category,
sub_type,
color,
design_data: DesignCreate,
) -> None:
"""降级使用 mock 生成器"""
save_path = os.path.join(settings.UPLOAD_DIR, "designs", f"{design.id}.png")
image_url = generate_mock_design(
category_name=category.name,
@@ -68,13 +185,47 @@ def create_design(db: Session, user_id: int, design_data: DesignCreate) -> Desig
surface_finish=design_data.surface_finish,
usage_scene=design_data.usage_scene,
)
# 更新设计记录
design.image_url = image_url
design.status = "completed"
logger.info(f"Mock 降级生成完成: design_id={design.id}")
def create_design(db: Session, user_id: int, design_data: DesignCreate) -> Design:
"""
同步版本创建设计(兼容旧调用,仅用 mock
"""
category = db.query(Category).filter(Category.id == design_data.category_id).first()
if not category:
raise ValueError(f"品类不存在: {design_data.category_id}")
sub_type = None
if design_data.sub_type_id:
sub_type = db.query(SubType).filter(SubType.id == design_data.sub_type_id).first()
color = None
if design_data.color_id:
color = db.query(Color).filter(Color.id == design_data.color_id).first()
design = Design(
user_id=user_id,
category_id=design_data.category_id,
sub_type_id=design_data.sub_type_id,
color_id=design_data.color_id,
prompt=design_data.prompt,
carving_technique=design_data.carving_technique,
design_style=design_data.design_style,
motif=design_data.motif,
size_spec=design_data.size_spec,
surface_finish=design_data.surface_finish,
usage_scene=design_data.usage_scene,
status="generating"
)
db.add(design)
db.flush()
_generate_mock_fallback(db, design, category, sub_type, color, design_data)
db.commit()
db.refresh(design)
return design
@@ -132,16 +283,24 @@ def delete_design(db: Session, design_id: int, user_id: int) -> bool:
if not design:
return False
# 删除图片文件
# 删除图片文件
if design.image_url:
# image_url 格式: /uploads/designs/1001.png
# 转换为实际文件路径
file_path = design.image_url.lstrip("/")
if os.path.exists(file_path):
try:
os.remove(file_path)
except Exception:
pass # 忽略删除失败
pass
# 删除多视角图片文件
for img in design.images:
if img.image_url:
fp = img.image_url.lstrip("/")
if os.path.exists(fp):
try:
os.remove(fp)
except Exception:
pass
# 删除数据库记录
db.delete(design)

View File

@@ -0,0 +1,164 @@
"""
专业玉雕设计提示词构建器(数据库版)
从数据库 prompt_templates + prompt_mappings 读取配置,支持后台热更新
"""
import logging
from typing import Optional, Dict, List
from ..database import SessionLocal
from ..models.prompt_template import PromptTemplate, PromptMapping
logger = logging.getLogger(__name__)
# ============================================================
# 品类视角配置(保留硬编码,因为与业务流程强关联)
# ============================================================
CATEGORY_VIEWS: Dict[str, List[str]] = {
"牌子": ["效果图", "正面图", "背面图"],
"珠子": ["效果图", "正面图"],
"手把件": ["效果图", "正面图", "侧面图", "背面图"],
"雕刻件": ["效果图", "正面图", "侧面图", "背面图"],
"摆件": ["效果图", "正面图", "侧面图", "背面图"],
"手镯": ["效果图", "正面图", "侧面图"],
"耳钉": ["效果图", "正面图"],
"耳饰": ["效果图", "正面图"],
"手链": ["效果图", "正面图"],
"项链": ["效果图", "正面图"],
"戒指": ["效果图", "正面图", "侧面图"],
"表带": ["效果图", "正面图"],
}
def _load_mappings(mapping_type: str) -> Dict[str, str]:
"""从数据库加载指定类型的映射字典"""
try:
db = SessionLocal()
try:
rows = db.query(PromptMapping).filter(
PromptMapping.mapping_type == mapping_type
).order_by(PromptMapping.sort_order).all()
return {r.cn_key: r.en_value for r in rows}
finally:
db.close()
except Exception as e:
logger.warning(f"加载映射 {mapping_type} 失败: {e}")
return {}
def _load_template(template_key: str, default: str = "") -> str:
"""从数据库加载模板"""
try:
db = SessionLocal()
try:
tpl = db.query(PromptTemplate).filter(
PromptTemplate.template_key == template_key
).first()
if tpl:
return tpl.template_value
finally:
db.close()
except Exception as e:
logger.warning(f"加载模板 {template_key} 失败: {e}")
return default
def get_views_for_category(category_name: str) -> List[str]:
"""获取品类对应的视角列表"""
return CATEGORY_VIEWS.get(category_name, ["效果图", "正面图"])
def build_prompt(
category_name: str,
view_name: str,
sub_type_name: Optional[str] = None,
color_name: Optional[str] = None,
user_prompt: Optional[str] = None,
carving_technique: Optional[str] = None,
design_style: Optional[str] = None,
motif: Optional[str] = None,
size_spec: Optional[str] = None,
surface_finish: Optional[str] = None,
usage_scene: Optional[str] = None,
) -> str:
"""
构建专业英文生图提示词(从数据库读取映射和模板)
业务逻辑:用户参数 → 中英映射 → 填入模板 → 最终prompt
"""
# 从数据库加载所有映射
category_map = _load_mappings("category")
color_map = _load_mappings("color")
view_map = _load_mappings("view")
carving_map = _load_mappings("carving")
style_map = _load_mappings("style")
motif_map = _load_mappings("motif")
finish_map = _load_mappings("finish")
scene_map = _load_mappings("scene")
sub_type_map = _load_mappings("sub_type")
# 加载模板
quality_suffix = _load_template("quality_suffix",
"professional jewelry product photography, studio lighting setup, pure white background, ultra-detailed, sharp focus, 8K resolution, photorealistic rendering, high-end commercial quality")
default_color = _load_template("default_color",
"natural Hetian nephrite jade with warm luster")
# 构建各部分
parts = []
# 1. 品类主体
subject = category_map.get(category_name, f"Chinese Hetian nephrite jade {category_name}")
parts.append(subject)
# 2. 子类型
if sub_type_name:
sub_detail = sub_type_map.get(sub_type_name, sub_type_name)
parts.append(sub_detail)
# 3. 颜色
if color_name:
color_desc = color_map.get(color_name, f"{color_name} colored nephrite jade")
parts.append(color_desc)
else:
parts.append(default_color)
# 4. 题材纹样
if motif:
motif_desc = motif_map.get(motif, f"{motif} themed design")
parts.append(f"featuring {motif_desc}")
# 5. 雕刻工艺
if carving_technique:
tech_desc = carving_map.get(carving_technique, carving_technique)
parts.append(tech_desc)
# 6. 设计风格
if design_style:
style_desc = style_map.get(design_style, design_style)
parts.append(style_desc)
# 7. 表面处理
if surface_finish:
finish_desc = finish_map.get(surface_finish, surface_finish)
parts.append(finish_desc)
# 8. 用途场景
if usage_scene:
scene_desc = scene_map.get(usage_scene, usage_scene)
parts.append(scene_desc)
# 9. 尺寸
if size_spec:
parts.append(f"size approximately {size_spec}")
# 10. 用户描述
if user_prompt:
parts.append(f"design concept: {user_prompt}")
# 11. 视角
view_desc = view_map.get(view_name, "three-quarter view")
parts.append(view_desc)
# 12. 质量后缀
parts.append(quality_suffix)
return ", ".join(parts)

View File

@@ -56,3 +56,19 @@ def get_current_user(
raise credentials_exception
return user
def get_admin_user(
current_user: User = Depends(get_current_user)
) -> User:
"""
获取当前管理员用户
验证当前用户是否为管理员,非管理员抛出 403
"""
if not current_user.is_admin:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="权限不足,需要管理员权限"
)
return current_user