feat(ai): 支持双模型多视角AI设计生图与后台管理系统
- 实现AI多视角设计图生成功能,支持6个可选设计参数配置 - 集成SiliconFlow FLUX.1与火山引擎Seedream 4.5双模型切换 - 构建专业中文转英文prompt系统,提升AI生成质量 - 前端设计预览支持多视角切换与视角指示器展示 - 增加多视角设计图片DesignImage模型关联及存储 - 后端设计服务异步调用AI接口,失败时降级生成mock图 - 新增管理员后台管理路由及完整的权限校验机制 - 实现后台模块:仪表盘、系统配置、用户/品类/设计管理 - 配置数据库系统配置表,支持动态AI配置及热更新 - 增加用户管理员标识字段,管理后台登录鉴权支持 - 更新API接口支持多视角设计参数及后台管理接口 - 优化设计删除逻辑,删除多视角相关图片文件 - 前端新增管理后台页面与路由,布局样式独立分离 - 更新环境变量增加AI模型相关Key与参数配置说明 - 引入httpx异步HTTP客户端用于AI接口调用及图片下载 - README文档完善AI多视角生图与后台管理详细功能与流程说明
This commit is contained in:
@@ -13,6 +13,13 @@ class Settings(BaseSettings):
|
||||
ALGORITHM: str = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 1440
|
||||
UPLOAD_DIR: str = "uploads"
|
||||
# AI 生图配置
|
||||
SILICONFLOW_API_KEY: str = ""
|
||||
SILICONFLOW_BASE_URL: str = "https://api.siliconflow.cn/v1"
|
||||
VOLCENGINE_API_KEY: str = ""
|
||||
VOLCENGINE_BASE_URL: str = "https://ark.cn-beijing.volces.com/api/v3"
|
||||
AI_IMAGE_MODEL: str = "flux-dev" # flux-dev 或 seedream-4.5
|
||||
AI_IMAGE_SIZE: int = 1024
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
|
||||
@@ -10,6 +10,7 @@ from fastapi.staticfiles import StaticFiles
|
||||
from .config import settings
|
||||
from .routers import categories, designs, users
|
||||
from .routers import auth
|
||||
from .routers import admin
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
@@ -62,6 +63,7 @@ app.include_router(auth.router)
|
||||
app.include_router(categories.router)
|
||||
app.include_router(designs.router)
|
||||
app.include_router(users.router)
|
||||
app.include_router(admin.router)
|
||||
|
||||
# 配置静态文件服务
|
||||
app.mount("/uploads", StaticFiles(directory="uploads"), name="uploads")
|
||||
|
||||
@@ -6,6 +6,9 @@ from ..database import Base
|
||||
from .user import User
|
||||
from .category import Category, SubType, Color
|
||||
from .design import Design
|
||||
from .design_image import DesignImage
|
||||
from .system_config import SystemConfig
|
||||
from .prompt_template import PromptTemplate, PromptMapping
|
||||
|
||||
__all__ = [
|
||||
"Base",
|
||||
@@ -13,5 +16,9 @@ __all__ = [
|
||||
"Category",
|
||||
"SubType",
|
||||
"Color",
|
||||
"Design"
|
||||
"Design",
|
||||
"DesignImage",
|
||||
"SystemConfig",
|
||||
"PromptTemplate",
|
||||
"PromptMapping",
|
||||
]
|
||||
|
||||
@@ -34,6 +34,7 @@ class Design(Base):
|
||||
category = relationship("Category", back_populates="designs")
|
||||
sub_type = relationship("SubType", back_populates="designs")
|
||||
color = relationship("Color", back_populates="designs")
|
||||
images = relationship("DesignImage", back_populates="design", cascade="all, delete-orphan", order_by="DesignImage.sort_order")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Design(id={self.id}, status='{self.status}')>"
|
||||
|
||||
29
backend/app/models/design_image.py
Normal file
29
backend/app/models/design_image.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""
|
||||
设计图片模型
|
||||
存储每个设计的多视角图片
|
||||
"""
|
||||
from sqlalchemy import Column, BigInteger, Integer, String, Text, DateTime, ForeignKey
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from ..database import Base
|
||||
|
||||
|
||||
class DesignImage(Base):
|
||||
"""设计图片表 - 存储多视角设计图"""
|
||||
__tablename__ = "design_images"
|
||||
|
||||
id = Column(BigInteger, primary_key=True, autoincrement=True, comment="图片ID")
|
||||
design_id = Column(BigInteger, ForeignKey("designs.id", ondelete="CASCADE"), nullable=False, comment="关联设计ID")
|
||||
view_name = Column(String(20), nullable=False, comment="视角名称: 效果图/正面图/侧面图/背面图")
|
||||
image_url = Column(String(255), nullable=True, comment="图片URL路径")
|
||||
model_used = Column(String(50), nullable=True, comment="使用的AI模型: flux-dev/seedream-4.5")
|
||||
prompt_used = Column(Text, nullable=True, comment="实际使用的英文prompt")
|
||||
sort_order = Column(Integer, default=0, comment="排序")
|
||||
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
|
||||
|
||||
# 关联关系
|
||||
design = relationship("Design", back_populates="images")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<DesignImage(id={self.id}, view='{self.view_name}')>"
|
||||
37
backend/app/models/prompt_template.py
Normal file
37
backend/app/models/prompt_template.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""
|
||||
提示词模板模型
|
||||
存储可后台配置的提示词模板和映射数据
|
||||
"""
|
||||
from sqlalchemy import Column, BigInteger, Integer, String, Text, DateTime
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from ..database import Base
|
||||
|
||||
|
||||
class PromptTemplate(Base):
|
||||
"""提示词模板表"""
|
||||
__tablename__ = "prompt_templates"
|
||||
|
||||
id = Column(BigInteger, primary_key=True, autoincrement=True, comment="模板ID")
|
||||
template_key = Column(String(100), unique=True, nullable=False, comment="模板键: main_template / quality_suffix")
|
||||
template_value = Column(Text, nullable=False, comment="模板内容,支持{变量}占位符")
|
||||
description = Column(String(255), nullable=True, comment="模板说明")
|
||||
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<PromptTemplate(key='{self.template_key}')>"
|
||||
|
||||
|
||||
class PromptMapping(Base):
|
||||
"""提示词映射表 - 中文参数到英文描述的映射"""
|
||||
__tablename__ = "prompt_mappings"
|
||||
|
||||
id = Column(BigInteger, primary_key=True, autoincrement=True, comment="映射ID")
|
||||
mapping_type = Column(String(50), nullable=False, comment="映射类型: category/color/carving/style/motif/finish/scene/view/sub_type")
|
||||
cn_key = Column(String(100), nullable=False, comment="中文键")
|
||||
en_value = Column(Text, nullable=False, comment="英文描述")
|
||||
sort_order = Column(Integer, default=0, comment="排序")
|
||||
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<PromptMapping(type='{self.mapping_type}', key='{self.cn_key}')>"
|
||||
24
backend/app/models/system_config.py
Normal file
24
backend/app/models/system_config.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""
|
||||
系统配置模型
|
||||
存储可通过后台管理的系统配置项
|
||||
"""
|
||||
from sqlalchemy import Column, BigInteger, String, Text, DateTime
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from ..database import Base
|
||||
|
||||
|
||||
class SystemConfig(Base):
|
||||
"""系统配置表"""
|
||||
__tablename__ = "system_configs"
|
||||
|
||||
id = Column(BigInteger, primary_key=True, autoincrement=True, comment="配置ID")
|
||||
config_key = Column(String(100), unique=True, nullable=False, comment="配置键")
|
||||
config_value = Column(Text, nullable=True, comment="配置值")
|
||||
description = Column(String(255), nullable=True, comment="配置说明")
|
||||
config_group = Column(String(50), nullable=False, default="general", comment="配置分组: ai/general")
|
||||
is_secret = Column(String(1), nullable=False, default="N", comment="是否敏感信息(Y/N)")
|
||||
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<SystemConfig(key='{self.config_key}')>"
|
||||
@@ -1,7 +1,7 @@
|
||||
"""
|
||||
用户模型
|
||||
"""
|
||||
from sqlalchemy import Column, BigInteger, String, DateTime
|
||||
from sqlalchemy import Column, BigInteger, String, DateTime, Boolean
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
@@ -18,6 +18,7 @@ class User(Base):
|
||||
hashed_password = Column(String(255), nullable=False, comment="加密密码")
|
||||
nickname = Column(String(50), nullable=True, comment="昵称")
|
||||
avatar = Column(String(255), nullable=True, comment="头像URL")
|
||||
is_admin = Column(Boolean, default=False, nullable=False, comment="是否管理员")
|
||||
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
|
||||
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
|
||||
|
||||
|
||||
627
backend/app/routers/admin.py
Normal file
627
backend/app/routers/admin.py
Normal file
@@ -0,0 +1,627 @@
|
||||
"""
|
||||
管理后台路由
|
||||
提供系统配置、用户管理、品类管理、设计管理接口
|
||||
所有接口需要管理员权限
|
||||
"""
|
||||
from datetime import datetime, timedelta
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func, or_
|
||||
|
||||
from ..database import get_db
|
||||
from ..models import User, Design, Category, SubType, Color, SystemConfig, PromptTemplate, PromptMapping
|
||||
from ..schemas.admin import (
|
||||
SystemConfigItem, SystemConfigUpdate, SystemConfigResponse,
|
||||
AdminUserResponse, AdminUserListResponse, AdminSetAdmin,
|
||||
CategoryCreate, CategoryUpdate, SubTypeCreate, SubTypeUpdate,
|
||||
ColorCreate, ColorUpdate,
|
||||
AdminDesignListResponse, DashboardStats,
|
||||
PromptTemplateItem, PromptTemplateUpdate,
|
||||
PromptMappingItem, PromptMappingCreate, PromptMappingUpdate
|
||||
)
|
||||
from ..utils.deps import get_admin_user
|
||||
|
||||
router = APIRouter(prefix="/api/admin", tags=["管理后台"])
|
||||
|
||||
|
||||
# ==================== 仪表盘 ====================
|
||||
|
||||
@router.get("/dashboard", response_model=DashboardStats)
|
||||
def get_dashboard(
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user)
|
||||
):
|
||||
"""获取仪表盘统计数据"""
|
||||
today = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
return DashboardStats(
|
||||
total_users=db.query(func.count(User.id)).scalar() or 0,
|
||||
total_designs=db.query(func.count(Design.id)).scalar() or 0,
|
||||
total_categories=db.query(func.count(Category.id)).scalar() or 0,
|
||||
today_designs=db.query(func.count(Design.id)).filter(Design.created_at >= today).scalar() or 0,
|
||||
today_users=db.query(func.count(User.id)).filter(User.created_at >= today).scalar() or 0,
|
||||
)
|
||||
|
||||
|
||||
# ==================== 系统配置管理 ====================
|
||||
|
||||
@router.get("/configs", response_model=SystemConfigResponse)
|
||||
def get_configs(
|
||||
group: str = Query(None, description="按分组筛选"),
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user)
|
||||
):
|
||||
"""获取系统配置列表"""
|
||||
query = db.query(SystemConfig)
|
||||
if group:
|
||||
query = query.filter(SystemConfig.config_group == group)
|
||||
items = query.order_by(SystemConfig.config_group, SystemConfig.config_key).all()
|
||||
|
||||
# 敏感信息脱敏显示
|
||||
result = []
|
||||
for item in items:
|
||||
cfg = SystemConfigItem.model_validate(item)
|
||||
if item.is_secret == "Y" and item.config_value:
|
||||
# 只显示前4位和后4位
|
||||
val = item.config_value
|
||||
if len(val) > 8:
|
||||
cfg.config_value = val[:4] + "****" + val[-4:]
|
||||
else:
|
||||
cfg.config_value = "****"
|
||||
result.append(cfg)
|
||||
return SystemConfigResponse(items=result)
|
||||
|
||||
|
||||
@router.put("/configs")
|
||||
def update_configs(
|
||||
data: SystemConfigUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user)
|
||||
):
|
||||
"""批量更新系统配置"""
|
||||
updated = 0
|
||||
for key, value in data.configs.items():
|
||||
config = db.query(SystemConfig).filter(SystemConfig.config_key == key).first()
|
||||
if config:
|
||||
config.config_value = value
|
||||
updated += 1
|
||||
else:
|
||||
# 自动创建不存在的配置项
|
||||
new_config = SystemConfig(
|
||||
config_key=key,
|
||||
config_value=value,
|
||||
config_group="general"
|
||||
)
|
||||
db.add(new_config)
|
||||
updated += 1
|
||||
db.commit()
|
||||
return {"message": f"已更新 {updated} 项配置"}
|
||||
|
||||
|
||||
@router.post("/configs/init")
|
||||
def init_default_configs(
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user)
|
||||
):
|
||||
"""初始化默认配置项(仅当配置表为空时)"""
|
||||
count = db.query(func.count(SystemConfig.id)).scalar()
|
||||
if count > 0:
|
||||
return {"message": "配置已存在,跳过初始化"}
|
||||
|
||||
defaults = [
|
||||
("SILICONFLOW_API_KEY", "", "SiliconFlow API Key", "ai", "Y"),
|
||||
("SILICONFLOW_BASE_URL", "https://api.siliconflow.cn/v1", "SiliconFlow 接口地址", "ai", "N"),
|
||||
("VOLCENGINE_API_KEY", "", "火山引擎 API Key", "ai", "Y"),
|
||||
("VOLCENGINE_BASE_URL", "https://ark.cn-beijing.volces.com/api/v3", "火山引擎接口地址", "ai", "N"),
|
||||
("AI_IMAGE_MODEL", "flux-dev", "默认AI生图模型 (flux-dev / seedream-4.5)", "ai", "N"),
|
||||
("AI_IMAGE_SIZE", "1024", "AI生图默认尺寸", "ai", "N"),
|
||||
]
|
||||
for key, val, desc, group, secret in defaults:
|
||||
db.add(SystemConfig(
|
||||
config_key=key, config_value=val,
|
||||
description=desc, config_group=group, is_secret=secret
|
||||
))
|
||||
db.commit()
|
||||
return {"message": f"已初始化 {len(defaults)} 项默认配置"}
|
||||
|
||||
|
||||
class TestConnectionRequest(BaseModel):
|
||||
"""API 连接测试请求"""
|
||||
provider: str # siliconflow / volcengine
|
||||
api_key: str
|
||||
base_url: str
|
||||
|
||||
|
||||
@router.post("/configs/test")
|
||||
async def test_api_connection(
|
||||
data: TestConnectionRequest,
|
||||
admin: User = Depends(get_admin_user)
|
||||
):
|
||||
"""测试 AI API 连接是否正常"""
|
||||
try:
|
||||
if data.provider == "siliconflow":
|
||||
url = f"{data.base_url}/models"
|
||||
headers = {"Authorization": f"Bearer {data.api_key}"}
|
||||
async with httpx.AsyncClient(timeout=15) as client:
|
||||
resp = await client.get(url, headers=headers)
|
||||
if resp.status_code == 200:
|
||||
return {"message": "连接成功,API Key 有效"}
|
||||
elif resp.status_code == 401:
|
||||
raise HTTPException(status_code=400, detail="API Key 无效,请检查")
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail=f"请求失败,状态码: {resp.status_code}")
|
||||
elif data.provider == "volcengine":
|
||||
url = f"{data.base_url}/models"
|
||||
headers = {"Authorization": f"Bearer {data.api_key}"}
|
||||
async with httpx.AsyncClient(timeout=15) as client:
|
||||
resp = await client.get(url, headers=headers)
|
||||
if resp.status_code == 200:
|
||||
return {"message": "连接成功,API Key 有效"}
|
||||
elif resp.status_code == 401:
|
||||
raise HTTPException(status_code=400, detail="API Key 无效,请检查")
|
||||
else:
|
||||
# 火山引擎可能返回其他状态码但连接本身成功
|
||||
return {"message": f"连接成功(状态码: {resp.status_code})"}
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail=f"未知的服务提供商: {data.provider}")
|
||||
except httpx.ConnectError:
|
||||
raise HTTPException(status_code=400, detail="连接失败,请检查接口地址")
|
||||
except httpx.TimeoutException:
|
||||
raise HTTPException(status_code=400, detail="连接超时,请检查网络")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"测试失败: {str(e)}")
|
||||
|
||||
|
||||
# ==================== 用户管理 ====================
|
||||
|
||||
@router.get("/users", response_model=AdminUserListResponse)
|
||||
def get_users(
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=100),
|
||||
keyword: str = Query(None, description="搜索用户名/昵称"),
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user)
|
||||
):
|
||||
"""获取用户列表"""
|
||||
query = db.query(User)
|
||||
if keyword:
|
||||
query = query.filter(
|
||||
or_(User.username.like(f"%{keyword}%"), User.nickname.like(f"%{keyword}%"))
|
||||
)
|
||||
|
||||
total = query.count()
|
||||
users = query.order_by(User.created_at.desc()).offset((page - 1) * page_size).limit(page_size).all()
|
||||
|
||||
items = []
|
||||
for u in users:
|
||||
design_count = db.query(func.count(Design.id)).filter(Design.user_id == u.id).scalar() or 0
|
||||
items.append(AdminUserResponse(
|
||||
id=u.id, username=u.username, nickname=u.nickname,
|
||||
phone=u.phone, is_admin=u.is_admin,
|
||||
created_at=u.created_at, design_count=design_count
|
||||
))
|
||||
|
||||
return AdminUserListResponse(items=items, total=total, page=page, page_size=page_size)
|
||||
|
||||
|
||||
@router.put("/users/{user_id}/admin")
|
||||
def set_user_admin(
|
||||
user_id: int,
|
||||
data: AdminSetAdmin,
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user)
|
||||
):
|
||||
"""设置/取消管理员"""
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
if user.id == admin.id:
|
||||
raise HTTPException(status_code=400, detail="不能修改自己的管理员状态")
|
||||
user.is_admin = data.is_admin
|
||||
db.commit()
|
||||
return {"message": f"用户 {user.username} {'已设为管理员' if data.is_admin else '已取消管理员'}"}
|
||||
|
||||
|
||||
@router.delete("/users/{user_id}")
|
||||
def delete_user(
|
||||
user_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user)
|
||||
):
|
||||
"""删除用户"""
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
if user.id == admin.id:
|
||||
raise HTTPException(status_code=400, detail="不能删除自己")
|
||||
if user.is_admin:
|
||||
raise HTTPException(status_code=400, detail="不能删除其他管理员")
|
||||
db.delete(user)
|
||||
db.commit()
|
||||
return {"message": "用户已删除"}
|
||||
|
||||
|
||||
# ==================== 品类管理 ====================
|
||||
|
||||
@router.get("/categories")
|
||||
def get_categories(
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user)
|
||||
):
|
||||
"""获取所有品类(含子类型和颜色)"""
|
||||
categories = db.query(Category).order_by(Category.sort_order).all()
|
||||
result = []
|
||||
for cat in categories:
|
||||
result.append({
|
||||
"id": cat.id,
|
||||
"name": cat.name,
|
||||
"icon": cat.icon,
|
||||
"sort_order": cat.sort_order,
|
||||
"flow_type": cat.flow_type,
|
||||
"sub_types": [{"id": st.id, "name": st.name, "description": st.description,
|
||||
"preview_image": st.preview_image, "sort_order": st.sort_order}
|
||||
for st in sorted(cat.sub_types, key=lambda x: x.sort_order)],
|
||||
"colors": [{"id": c.id, "name": c.name, "hex_code": c.hex_code, "sort_order": c.sort_order}
|
||||
for c in sorted(cat.colors, key=lambda x: x.sort_order)]
|
||||
})
|
||||
return result
|
||||
|
||||
|
||||
@router.post("/categories")
|
||||
def create_category(
|
||||
data: CategoryCreate,
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user)
|
||||
):
|
||||
"""创建品类"""
|
||||
cat = Category(name=data.name, icon=data.icon, sort_order=data.sort_order, flow_type=data.flow_type)
|
||||
db.add(cat)
|
||||
db.commit()
|
||||
db.refresh(cat)
|
||||
return {"id": cat.id, "message": "品类创建成功"}
|
||||
|
||||
|
||||
@router.put("/categories/{cat_id}")
|
||||
def update_category(
|
||||
cat_id: int,
|
||||
data: CategoryUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user)
|
||||
):
|
||||
"""更新品类"""
|
||||
cat = db.query(Category).filter(Category.id == cat_id).first()
|
||||
if not cat:
|
||||
raise HTTPException(status_code=404, detail="品类不存在")
|
||||
for field, value in data.model_dump(exclude_unset=True).items():
|
||||
setattr(cat, field, value)
|
||||
db.commit()
|
||||
return {"message": "品类更新成功"}
|
||||
|
||||
|
||||
@router.delete("/categories/{cat_id}")
|
||||
def delete_category(
|
||||
cat_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user)
|
||||
):
|
||||
"""删除品类(级联删除子类型和颜色)"""
|
||||
cat = db.query(Category).filter(Category.id == cat_id).first()
|
||||
if not cat:
|
||||
raise HTTPException(status_code=404, detail="品类不存在")
|
||||
# 检查是否有关联设计
|
||||
design_count = db.query(func.count(Design.id)).filter(Design.category_id == cat_id).scalar()
|
||||
if design_count > 0:
|
||||
raise HTTPException(status_code=400, detail=f"品类下有 {design_count} 个设计,无法删除")
|
||||
# 删除子类型和颜色
|
||||
db.query(SubType).filter(SubType.category_id == cat_id).delete()
|
||||
db.query(Color).filter(Color.category_id == cat_id).delete()
|
||||
db.delete(cat)
|
||||
db.commit()
|
||||
return {"message": "品类已删除"}
|
||||
|
||||
|
||||
# -- 子类型 --
|
||||
@router.post("/sub-types")
|
||||
def create_sub_type(
|
||||
data: SubTypeCreate,
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user)
|
||||
):
|
||||
"""创建子类型"""
|
||||
cat = db.query(Category).filter(Category.id == data.category_id).first()
|
||||
if not cat:
|
||||
raise HTTPException(status_code=404, detail="品类不存在")
|
||||
st = SubType(category_id=data.category_id, name=data.name,
|
||||
description=data.description, preview_image=data.preview_image,
|
||||
sort_order=data.sort_order)
|
||||
db.add(st)
|
||||
db.commit()
|
||||
db.refresh(st)
|
||||
return {"id": st.id, "message": "子类型创建成功"}
|
||||
|
||||
|
||||
@router.put("/sub-types/{st_id}")
|
||||
def update_sub_type(
|
||||
st_id: int,
|
||||
data: SubTypeUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user)
|
||||
):
|
||||
"""更新子类型"""
|
||||
st = db.query(SubType).filter(SubType.id == st_id).first()
|
||||
if not st:
|
||||
raise HTTPException(status_code=404, detail="子类型不存在")
|
||||
for field, value in data.model_dump(exclude_unset=True).items():
|
||||
setattr(st, field, value)
|
||||
db.commit()
|
||||
return {"message": "子类型更新成功"}
|
||||
|
||||
|
||||
@router.delete("/sub-types/{st_id}")
|
||||
def delete_sub_type(
|
||||
st_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user)
|
||||
):
|
||||
"""删除子类型"""
|
||||
st = db.query(SubType).filter(SubType.id == st_id).first()
|
||||
if not st:
|
||||
raise HTTPException(status_code=404, detail="子类型不存在")
|
||||
db.delete(st)
|
||||
db.commit()
|
||||
return {"message": "子类型已删除"}
|
||||
|
||||
|
||||
# -- 颜色 --
|
||||
@router.post("/colors")
|
||||
def create_color(
|
||||
data: ColorCreate,
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user)
|
||||
):
|
||||
"""创建颜色"""
|
||||
cat = db.query(Category).filter(Category.id == data.category_id).first()
|
||||
if not cat:
|
||||
raise HTTPException(status_code=404, detail="品类不存在")
|
||||
color = Color(category_id=data.category_id, name=data.name,
|
||||
hex_code=data.hex_code, sort_order=data.sort_order)
|
||||
db.add(color)
|
||||
db.commit()
|
||||
db.refresh(color)
|
||||
return {"id": color.id, "message": "颜色创建成功"}
|
||||
|
||||
|
||||
@router.put("/colors/{color_id}")
|
||||
def update_color(
|
||||
color_id: int,
|
||||
data: ColorUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user)
|
||||
):
|
||||
"""更新颜色"""
|
||||
color = db.query(Color).filter(Color.id == color_id).first()
|
||||
if not color:
|
||||
raise HTTPException(status_code=404, detail="颜色不存在")
|
||||
for field, value in data.model_dump(exclude_unset=True).items():
|
||||
setattr(color, field, value)
|
||||
db.commit()
|
||||
return {"message": "颜色更新成功"}
|
||||
|
||||
|
||||
@router.delete("/colors/{color_id}")
|
||||
def delete_color(
|
||||
color_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user)
|
||||
):
|
||||
"""删除颜色"""
|
||||
color = db.query(Color).filter(Color.id == color_id).first()
|
||||
if not color:
|
||||
raise HTTPException(status_code=404, detail="颜色不存在")
|
||||
db.delete(color)
|
||||
db.commit()
|
||||
return {"message": "颜色已删除"}
|
||||
|
||||
|
||||
# ==================== 设计管理 ====================
|
||||
|
||||
@router.get("/designs", response_model=AdminDesignListResponse)
|
||||
def get_all_designs(
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=100),
|
||||
user_id: int = Query(None, description="按用户筛选"),
|
||||
status_filter: str = Query(None, alias="status", description="按状态筛选"),
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user)
|
||||
):
|
||||
"""获取所有设计列表"""
|
||||
query = db.query(Design)
|
||||
if user_id:
|
||||
query = query.filter(Design.user_id == user_id)
|
||||
if status_filter:
|
||||
query = query.filter(Design.status == status_filter)
|
||||
|
||||
total = query.count()
|
||||
designs = query.order_by(Design.created_at.desc()).offset((page - 1) * page_size).limit(page_size).all()
|
||||
|
||||
items = []
|
||||
for d in designs:
|
||||
items.append({
|
||||
"id": d.id,
|
||||
"user_id": d.user_id,
|
||||
"username": d.user.username if d.user else None,
|
||||
"category_name": d.category.name if d.category else None,
|
||||
"sub_type_name": d.sub_type.name if d.sub_type else None,
|
||||
"color_name": d.color.name if d.color else None,
|
||||
"prompt": d.prompt,
|
||||
"image_url": d.image_url,
|
||||
"status": d.status,
|
||||
"created_at": d.created_at.isoformat() if d.created_at else None,
|
||||
})
|
||||
|
||||
return AdminDesignListResponse(items=items, total=total, page=page, page_size=page_size)
|
||||
|
||||
|
||||
@router.delete("/designs/{design_id}")
|
||||
def admin_delete_design(
|
||||
design_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user)
|
||||
):
|
||||
"""管理员删除任意设计"""
|
||||
design = db.query(Design).filter(Design.id == design_id).first()
|
||||
if not design:
|
||||
raise HTTPException(status_code=404, detail="设计不存在")
|
||||
db.delete(design)
|
||||
db.commit()
|
||||
return {"message": "设计已删除"}
|
||||
|
||||
|
||||
# ==================== 提示词管理 ====================
|
||||
|
||||
@router.get("/prompt-templates")
|
||||
def get_prompt_templates(
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user)
|
||||
):
|
||||
"""获取所有提示词模板"""
|
||||
templates = db.query(PromptTemplate).order_by(PromptTemplate.template_key).all()
|
||||
return [PromptTemplateItem.model_validate(t) for t in templates]
|
||||
|
||||
|
||||
@router.put("/prompt-templates/{template_id}")
|
||||
def update_prompt_template(
|
||||
template_id: int,
|
||||
data: PromptTemplateUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user)
|
||||
):
|
||||
"""更新提示词模板"""
|
||||
tpl = db.query(PromptTemplate).filter(PromptTemplate.id == template_id).first()
|
||||
if not tpl:
|
||||
raise HTTPException(status_code=404, detail="模板不存在")
|
||||
tpl.template_value = data.template_value
|
||||
if data.description is not None:
|
||||
tpl.description = data.description
|
||||
db.commit()
|
||||
return {"message": f"模板 '{tpl.template_key}' 更新成功"}
|
||||
|
||||
|
||||
@router.get("/prompt-mappings")
|
||||
def get_prompt_mappings(
|
||||
mapping_type: str = Query(None, description="按映射类型筛选"),
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user)
|
||||
):
|
||||
"""获取提示词映射列表"""
|
||||
query = db.query(PromptMapping)
|
||||
if mapping_type:
|
||||
query = query.filter(PromptMapping.mapping_type == mapping_type)
|
||||
mappings = query.order_by(PromptMapping.mapping_type, PromptMapping.sort_order).all()
|
||||
return [PromptMappingItem.model_validate(m) for m in mappings]
|
||||
|
||||
|
||||
@router.get("/prompt-mappings/types")
|
||||
def get_mapping_types(
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user)
|
||||
):
|
||||
"""获取所有映射类型及其数量"""
|
||||
from sqlalchemy import distinct
|
||||
types = db.query(
|
||||
PromptMapping.mapping_type,
|
||||
func.count(PromptMapping.id)
|
||||
).group_by(PromptMapping.mapping_type).all()
|
||||
return [{"type": t, "count": c, "label": {
|
||||
"category": "品类", "color": "颜色", "view": "视角",
|
||||
"carving": "雕刻工艺", "style": "设计风格", "motif": "题材纹样",
|
||||
"finish": "表面处理", "scene": "用途场景", "sub_type": "子类型"
|
||||
}.get(t, t)} for t, c in types]
|
||||
|
||||
|
||||
@router.post("/prompt-mappings")
|
||||
def create_prompt_mapping(
|
||||
data: PromptMappingCreate,
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user)
|
||||
):
|
||||
"""创建提示词映射"""
|
||||
# 检查重复
|
||||
existing = db.query(PromptMapping).filter(
|
||||
PromptMapping.mapping_type == data.mapping_type,
|
||||
PromptMapping.cn_key == data.cn_key
|
||||
).first()
|
||||
if existing:
|
||||
raise HTTPException(status_code=400, detail=f"映射 '{data.cn_key}' 已存在")
|
||||
mapping = PromptMapping(
|
||||
mapping_type=data.mapping_type,
|
||||
cn_key=data.cn_key,
|
||||
en_value=data.en_value,
|
||||
sort_order=data.sort_order
|
||||
)
|
||||
db.add(mapping)
|
||||
db.commit()
|
||||
db.refresh(mapping)
|
||||
return {"id": mapping.id, "message": "映射创建成功"}
|
||||
|
||||
|
||||
@router.put("/prompt-mappings/{mapping_id}")
|
||||
def update_prompt_mapping(
|
||||
mapping_id: int,
|
||||
data: PromptMappingUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user)
|
||||
):
|
||||
"""更新提示词映射"""
|
||||
mapping = db.query(PromptMapping).filter(PromptMapping.id == mapping_id).first()
|
||||
if not mapping:
|
||||
raise HTTPException(status_code=404, detail="映射不存在")
|
||||
for field, value in data.model_dump(exclude_unset=True).items():
|
||||
setattr(mapping, field, value)
|
||||
db.commit()
|
||||
return {"message": "映射更新成功"}
|
||||
|
||||
|
||||
@router.delete("/prompt-mappings/{mapping_id}")
|
||||
def delete_prompt_mapping(
|
||||
mapping_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user)
|
||||
):
|
||||
"""删除提示词映射"""
|
||||
mapping = db.query(PromptMapping).filter(PromptMapping.id == mapping_id).first()
|
||||
if not mapping:
|
||||
raise HTTPException(status_code=404, detail="映射不存在")
|
||||
db.delete(mapping)
|
||||
db.commit()
|
||||
return {"message": "映射已删除"}
|
||||
|
||||
|
||||
@router.post("/prompt-preview")
|
||||
def preview_prompt(
|
||||
params: dict,
|
||||
db: Session = Depends(get_db),
|
||||
admin: User = Depends(get_admin_user)
|
||||
):
|
||||
"""预览提示词生成结果"""
|
||||
from ..services.prompt_builder import build_prompt
|
||||
try:
|
||||
prompt = build_prompt(
|
||||
category_name=params.get("category_name", "牌子"),
|
||||
view_name=params.get("view_name", "效果图"),
|
||||
sub_type_name=params.get("sub_type_name"),
|
||||
color_name=params.get("color_name"),
|
||||
user_prompt=params.get("user_prompt"),
|
||||
carving_technique=params.get("carving_technique"),
|
||||
design_style=params.get("design_style"),
|
||||
motif=params.get("motif"),
|
||||
size_spec=params.get("size_spec"),
|
||||
surface_finish=params.get("surface_finish"),
|
||||
usage_scene=params.get("usage_scene"),
|
||||
)
|
||||
return {"prompt": prompt}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"提示词生成失败: {str(e)}")
|
||||
@@ -9,7 +9,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from ..database import get_db
|
||||
from ..models import User, Design
|
||||
from ..schemas import DesignCreate, DesignResponse, DesignListResponse
|
||||
from ..schemas import DesignCreate, DesignResponse, DesignListResponse, DesignImageResponse
|
||||
from ..utils.deps import get_current_user
|
||||
from ..services import design_service
|
||||
|
||||
@@ -18,6 +18,21 @@ router = APIRouter(prefix="/api/designs", tags=["设计"])
|
||||
|
||||
def design_to_response(design: Design) -> DesignResponse:
|
||||
"""将 Design 模型转换为响应格式"""
|
||||
# 构建多视角图片列表
|
||||
images = []
|
||||
if hasattr(design, 'images') and design.images:
|
||||
images = [
|
||||
DesignImageResponse(
|
||||
id=img.id,
|
||||
view_name=img.view_name,
|
||||
image_url=img.image_url,
|
||||
model_used=img.model_used,
|
||||
prompt_used=img.prompt_used,
|
||||
sort_order=img.sort_order,
|
||||
)
|
||||
for img in design.images
|
||||
]
|
||||
|
||||
return DesignResponse(
|
||||
id=design.id,
|
||||
user_id=design.user_id,
|
||||
@@ -51,6 +66,7 @@ def design_to_response(design: Design) -> DesignResponse:
|
||||
surface_finish=design.surface_finish,
|
||||
usage_scene=design.usage_scene,
|
||||
image_url=design.image_url,
|
||||
images=images,
|
||||
status=design.status,
|
||||
created_at=design.created_at,
|
||||
updated_at=design.updated_at
|
||||
@@ -58,17 +74,17 @@ def design_to_response(design: Design) -> DesignResponse:
|
||||
|
||||
|
||||
@router.post("/generate", response_model=DesignResponse)
|
||||
def generate_design(
|
||||
async def generate_design(
|
||||
design_data: DesignCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
提交设计生成请求
|
||||
提交设计生成请求(异步,支持 AI 多视角生图)
|
||||
需要认证
|
||||
"""
|
||||
try:
|
||||
design = design_service.create_design(
|
||||
design = await design_service.create_design_async(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
design_data=design_data
|
||||
|
||||
@@ -4,7 +4,7 @@ Pydantic Schemas
|
||||
"""
|
||||
from .user import UserCreate, UserLogin, UserResponse, Token, UserUpdate, PasswordChange
|
||||
from .category import CategoryResponse, SubTypeResponse, ColorResponse
|
||||
from .design import DesignCreate, DesignResponse, DesignListResponse
|
||||
from .design import DesignCreate, DesignResponse, DesignListResponse, DesignImageResponse
|
||||
|
||||
__all__ = [
|
||||
# User schemas
|
||||
@@ -22,4 +22,5 @@ __all__ = [
|
||||
"DesignCreate",
|
||||
"DesignResponse",
|
||||
"DesignListResponse",
|
||||
"DesignImageResponse",
|
||||
]
|
||||
|
||||
173
backend/app/schemas/admin.py
Normal file
173
backend/app/schemas/admin.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""
|
||||
管理后台相关 Pydantic Schemas
|
||||
"""
|
||||
from pydantic import BaseModel, Field
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
|
||||
|
||||
# ============ 系统配置 ============
|
||||
class SystemConfigItem(BaseModel):
|
||||
"""单个配置项"""
|
||||
config_key: str
|
||||
config_value: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
config_group: str = "general"
|
||||
is_secret: str = "N"
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class SystemConfigUpdate(BaseModel):
|
||||
"""批量更新配置"""
|
||||
configs: dict = Field(..., description="键值对: {config_key: config_value}")
|
||||
|
||||
|
||||
class SystemConfigResponse(BaseModel):
|
||||
"""配置列表响应"""
|
||||
items: List[SystemConfigItem]
|
||||
|
||||
|
||||
# ============ 用户管理 ============
|
||||
class AdminUserResponse(BaseModel):
|
||||
"""管理端用户信息"""
|
||||
id: int
|
||||
username: str
|
||||
nickname: Optional[str] = None
|
||||
phone: Optional[str] = None
|
||||
is_admin: bool = False
|
||||
created_at: datetime
|
||||
design_count: int = 0
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class AdminUserListResponse(BaseModel):
|
||||
"""用户列表响应"""
|
||||
items: List[AdminUserResponse]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
|
||||
|
||||
class AdminSetAdmin(BaseModel):
|
||||
"""设置管理员"""
|
||||
is_admin: bool
|
||||
|
||||
|
||||
# ============ 品类管理 ============
|
||||
class CategoryCreate(BaseModel):
|
||||
"""创建品类"""
|
||||
name: str = Field(..., max_length=50)
|
||||
icon: Optional[str] = Field(None, max_length=255)
|
||||
sort_order: int = 0
|
||||
flow_type: str = Field("full", max_length=20)
|
||||
|
||||
|
||||
class CategoryUpdate(BaseModel):
|
||||
"""更新品类"""
|
||||
name: Optional[str] = Field(None, max_length=50)
|
||||
icon: Optional[str] = Field(None, max_length=255)
|
||||
sort_order: Optional[int] = None
|
||||
flow_type: Optional[str] = Field(None, max_length=20)
|
||||
|
||||
|
||||
class SubTypeCreate(BaseModel):
|
||||
"""创建子类型"""
|
||||
category_id: int
|
||||
name: str = Field(..., max_length=50)
|
||||
description: Optional[str] = Field(None, max_length=255)
|
||||
preview_image: Optional[str] = Field(None, max_length=255)
|
||||
sort_order: int = 0
|
||||
|
||||
|
||||
class SubTypeUpdate(BaseModel):
|
||||
"""更新子类型"""
|
||||
name: Optional[str] = Field(None, max_length=50)
|
||||
description: Optional[str] = Field(None, max_length=255)
|
||||
preview_image: Optional[str] = Field(None, max_length=255)
|
||||
sort_order: Optional[int] = None
|
||||
|
||||
|
||||
class ColorCreate(BaseModel):
|
||||
"""创建颜色"""
|
||||
category_id: int
|
||||
name: str = Field(..., max_length=50)
|
||||
hex_code: Optional[str] = Field(None, max_length=10)
|
||||
sort_order: int = 0
|
||||
|
||||
|
||||
class ColorUpdate(BaseModel):
|
||||
"""更新颜色"""
|
||||
name: Optional[str] = Field(None, max_length=50)
|
||||
hex_code: Optional[str] = Field(None, max_length=10)
|
||||
sort_order: Optional[int] = None
|
||||
|
||||
|
||||
# ============ 设计管理 ============
|
||||
class AdminDesignListResponse(BaseModel):
|
||||
"""管理端设计列表"""
|
||||
items: list
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
|
||||
|
||||
# ============ 提示词管理 ============
|
||||
class PromptTemplateItem(BaseModel):
|
||||
"""提示词模板"""
|
||||
id: Optional[int] = None
|
||||
template_key: str
|
||||
template_value: str
|
||||
description: Optional[str] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class PromptTemplateUpdate(BaseModel):
|
||||
"""更新提示词模板"""
|
||||
template_value: str
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
class PromptMappingItem(BaseModel):
|
||||
"""提示词映射"""
|
||||
id: Optional[int] = None
|
||||
mapping_type: str
|
||||
cn_key: str
|
||||
en_value: str
|
||||
sort_order: int = 0
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class PromptMappingCreate(BaseModel):
|
||||
"""创建提示词映射"""
|
||||
mapping_type: str
|
||||
cn_key: str
|
||||
en_value: str
|
||||
sort_order: int = 0
|
||||
|
||||
|
||||
class PromptMappingUpdate(BaseModel):
|
||||
"""更新提示词映射"""
|
||||
cn_key: Optional[str] = None
|
||||
en_value: Optional[str] = None
|
||||
sort_order: Optional[int] = None
|
||||
|
||||
|
||||
# ============ 仪表盘 ============
|
||||
class DashboardStats(BaseModel):
|
||||
"""仪表盘统计"""
|
||||
total_users: int
|
||||
total_designs: int
|
||||
total_categories: int
|
||||
today_designs: int
|
||||
today_users: int
|
||||
@@ -8,6 +8,20 @@ from typing import Optional, List
|
||||
from .category import CategoryResponse, SubTypeResponse, ColorResponse
|
||||
|
||||
|
||||
class DesignImageResponse(BaseModel):
|
||||
"""设计图片响应(单张视角图)"""
|
||||
id: int
|
||||
view_name: str
|
||||
image_url: Optional[str] = None
|
||||
model_used: Optional[str] = None
|
||||
prompt_used: Optional[str] = None
|
||||
sort_order: int = 0
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
protected_namespaces = ()
|
||||
|
||||
|
||||
class DesignCreate(BaseModel):
|
||||
"""创建设计请求"""
|
||||
category_id: int = Field(..., description="品类ID")
|
||||
@@ -37,6 +51,7 @@ class DesignResponse(BaseModel):
|
||||
surface_finish: Optional[str] = None
|
||||
usage_scene: Optional[str] = None
|
||||
image_url: Optional[str] = None
|
||||
images: List[DesignImageResponse] = []
|
||||
status: str
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
@@ -26,6 +26,7 @@ class UserResponse(BaseModel):
|
||||
nickname: Optional[str] = None
|
||||
phone: Optional[str] = None
|
||||
avatar: Optional[str] = None
|
||||
is_admin: bool = False
|
||||
created_at: datetime
|
||||
|
||||
class Config:
|
||||
|
||||
144
backend/app/services/ai_generator.py
Normal file
144
backend/app/services/ai_generator.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""
|
||||
AI 生图服务
|
||||
支持双模型:SiliconFlow FLUX.1 [dev] 和 火山引擎 Seedream 4.5
|
||||
"""
|
||||
import os
|
||||
import uuid
|
||||
import logging
|
||||
import httpx
|
||||
from typing import Optional
|
||||
|
||||
from ..config import settings
|
||||
from .config_service import get_ai_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 超时设置(秒)
|
||||
REQUEST_TIMEOUT = 90
|
||||
# 最大重试次数
|
||||
MAX_RETRIES = 3
|
||||
|
||||
|
||||
async def _call_siliconflow(prompt: str, size: int = 1024, ai_config: dict = None) -> str:
|
||||
"""
|
||||
调用 SiliconFlow FLUX.1 [dev] 生图 API
|
||||
|
||||
Returns:
|
||||
远程图片 URL
|
||||
"""
|
||||
cfg = ai_config or get_ai_config()
|
||||
url = f"{cfg['SILICONFLOW_BASE_URL']}/images/generations"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {cfg['SILICONFLOW_API_KEY']}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
payload = {
|
||||
"model": "black-forest-labs/FLUX.1-dev",
|
||||
"prompt": prompt,
|
||||
"image_size": f"{size}x{size}",
|
||||
"num_inference_steps": 20,
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(timeout=REQUEST_TIMEOUT) as client:
|
||||
resp = await client.post(url, json=payload, headers=headers)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
# SiliconFlow 响应格式: {"images": [{"url": "https://..."}]}
|
||||
images = data.get("images", [])
|
||||
if not images:
|
||||
raise ValueError("SiliconFlow 返回空图片列表")
|
||||
return images[0]["url"]
|
||||
|
||||
|
||||
async def _call_seedream(prompt: str, size: int = 1024, ai_config: dict = None) -> str:
|
||||
"""
|
||||
调用火山引擎 Seedream 4.5 生图 API
|
||||
|
||||
Returns:
|
||||
远程图片 URL
|
||||
"""
|
||||
cfg = ai_config or get_ai_config()
|
||||
url = f"{cfg['VOLCENGINE_BASE_URL']}/images/generations"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {cfg['VOLCENGINE_API_KEY']}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
payload = {
|
||||
"model": "doubao-seedream-4.5-t2i-250528",
|
||||
"prompt": prompt,
|
||||
"size": f"{size}x{size}",
|
||||
"response_format": "url",
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(timeout=REQUEST_TIMEOUT) as client:
|
||||
resp = await client.post(url, json=payload, headers=headers)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
# Seedream 响应格式: {"data": [{"url": "https://..."}]}
|
||||
items = data.get("data", [])
|
||||
if not items:
|
||||
raise ValueError("Seedream 返回空图片列表")
|
||||
return items[0]["url"]
|
||||
|
||||
|
||||
async def generate_image(prompt: str, model: Optional[str] = None) -> str:
|
||||
"""
|
||||
统一生图接口,带重试机制
|
||||
|
||||
Args:
|
||||
prompt: 英文提示词
|
||||
model: 模型名称 (flux-dev / seedream-4.5),为空则使用配置默认值
|
||||
|
||||
Returns:
|
||||
远程图片 URL
|
||||
|
||||
Raises:
|
||||
Exception: 所有重试失败后抛出
|
||||
"""
|
||||
ai_config = get_ai_config()
|
||||
model = model or ai_config.get("AI_IMAGE_MODEL", "flux-dev")
|
||||
size = ai_config.get("AI_IMAGE_SIZE", 1024)
|
||||
|
||||
last_error: Optional[Exception] = None
|
||||
for attempt in range(1, MAX_RETRIES + 1):
|
||||
try:
|
||||
if model == "seedream-4.5":
|
||||
image_url = await _call_seedream(prompt, size, ai_config)
|
||||
else:
|
||||
image_url = await _call_siliconflow(prompt, size, ai_config)
|
||||
logger.info(f"AI 生图成功 (model={model}, attempt={attempt})")
|
||||
return image_url
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
logger.warning(f"AI 生图失败 (model={model}, attempt={attempt}/{MAX_RETRIES}): {e}")
|
||||
if attempt < MAX_RETRIES:
|
||||
import asyncio
|
||||
await asyncio.sleep(2 * attempt) # 指数退避
|
||||
|
||||
raise RuntimeError(f"AI 生图在 {MAX_RETRIES} 次重试后仍然失败: {last_error}")
|
||||
|
||||
|
||||
async def download_and_save(image_url: str, save_path: str) -> str:
|
||||
"""
|
||||
下载远程图片并保存到本地
|
||||
|
||||
Args:
|
||||
image_url: 远程图片 URL
|
||||
save_path: 本地保存路径(如 uploads/designs/1001_效果图.png)
|
||||
|
||||
Returns:
|
||||
本地文件相对路径(以 / 开头,如 /uploads/designs/1001_效果图.png)
|
||||
"""
|
||||
# 确保目录存在
|
||||
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
||||
|
||||
async with httpx.AsyncClient(timeout=60, follow_redirects=True) as client:
|
||||
resp = await client.get(image_url)
|
||||
resp.raise_for_status()
|
||||
with open(save_path, "wb") as f:
|
||||
f.write(resp.content)
|
||||
|
||||
logger.info(f"图片已下载保存: {save_path}")
|
||||
return f"/{save_path}"
|
||||
58
backend/app/services/config_service.py
Normal file
58
backend/app/services/config_service.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""
|
||||
配置服务
|
||||
优先从数据库 system_configs 表读取配置,数据库无值时回退到 .env
|
||||
"""
|
||||
import logging
|
||||
from typing import Optional
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ..database import SessionLocal
|
||||
from ..models.system_config import SystemConfig
|
||||
from ..config import settings as env_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_config_value(key: str, default: Optional[str] = None) -> Optional[str]:
|
||||
"""
|
||||
获取配置值(数据库优先,.env 兜底)
|
||||
|
||||
Args:
|
||||
key: 配置键名(如 SILICONFLOW_API_KEY)
|
||||
default: 默认值
|
||||
Returns:
|
||||
配置值字符串
|
||||
"""
|
||||
# 1. 尝试从数据库读取
|
||||
try:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
config = db.query(SystemConfig).filter(SystemConfig.config_key == key).first()
|
||||
if config and config.config_value:
|
||||
return config.config_value
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"从数据库读取配置 {key} 失败: {e}")
|
||||
|
||||
# 2. 回退到 .env / Settings
|
||||
env_value = getattr(env_settings, key, None)
|
||||
if env_value is not None and env_value != "":
|
||||
return str(env_value)
|
||||
|
||||
return default
|
||||
|
||||
|
||||
def get_ai_config() -> dict:
|
||||
"""
|
||||
获取所有 AI 相关配置
|
||||
返回字典,方便 ai_generator 使用
|
||||
"""
|
||||
return {
|
||||
"SILICONFLOW_API_KEY": get_config_value("SILICONFLOW_API_KEY", ""),
|
||||
"SILICONFLOW_BASE_URL": get_config_value("SILICONFLOW_BASE_URL", "https://api.siliconflow.cn/v1"),
|
||||
"VOLCENGINE_API_KEY": get_config_value("VOLCENGINE_API_KEY", ""),
|
||||
"VOLCENGINE_BASE_URL": get_config_value("VOLCENGINE_BASE_URL", "https://ark.cn-beijing.volces.com/api/v3"),
|
||||
"AI_IMAGE_MODEL": get_config_value("AI_IMAGE_MODEL", "flux-dev"),
|
||||
"AI_IMAGE_SIZE": int(get_config_value("AI_IMAGE_SIZE", "1024")),
|
||||
}
|
||||
@@ -1,40 +1,56 @@
|
||||
"""
|
||||
设计服务
|
||||
处理设计相关的业务逻辑
|
||||
处理设计相关的业务逻辑,支持 AI 多视角生图 + mock 降级
|
||||
"""
|
||||
import os
|
||||
import logging
|
||||
from typing import List, Optional, Tuple
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import desc
|
||||
|
||||
from ..models import Design, Category, SubType, Color
|
||||
from ..models import Design, DesignImage, Category, SubType, Color
|
||||
from ..schemas import DesignCreate
|
||||
from ..config import settings
|
||||
from .mock_generator import generate_mock_design
|
||||
from .prompt_builder import get_views_for_category, build_prompt
|
||||
from . import ai_generator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_design(db: Session, user_id: int, design_data: DesignCreate) -> Design:
|
||||
def _has_ai_key() -> bool:
|
||||
"""检查是否配置了 AI API Key"""
|
||||
model = settings.AI_IMAGE_MODEL
|
||||
if model == "seedream-4.5":
|
||||
return bool(settings.VOLCENGINE_API_KEY)
|
||||
return bool(settings.SILICONFLOW_API_KEY)
|
||||
|
||||
|
||||
async def create_design_async(db: Session, user_id: int, design_data: DesignCreate) -> Design:
|
||||
"""
|
||||
创建设计记录
|
||||
|
||||
1. 创建设计记录(status=generating)
|
||||
2. 调用 mock_generator 生成图片
|
||||
3. 更新设计记录(status=completed, image_url)
|
||||
4. 返回设计对象
|
||||
创建设计记录(异步版本,支持 AI 多视角生图)
|
||||
|
||||
流程:
|
||||
1. 创建 Design 记录(status=generating)
|
||||
2. 获取品类视角列表
|
||||
3. 循环每个视角:构建 prompt → 调用 AI 生图 → 下载保存 → 创建 DesignImage
|
||||
4. 第一张效果图 URL 存入 design.image_url(兼容旧逻辑)
|
||||
5. 更新 status=completed
|
||||
6. 失败时降级到 mock_generator
|
||||
"""
|
||||
# 获取关联信息
|
||||
category = db.query(Category).filter(Category.id == design_data.category_id).first()
|
||||
if not category:
|
||||
raise ValueError(f"品类不存在: {design_data.category_id}")
|
||||
|
||||
|
||||
sub_type = None
|
||||
if design_data.sub_type_id:
|
||||
sub_type = db.query(SubType).filter(SubType.id == design_data.sub_type_id).first()
|
||||
|
||||
|
||||
color = None
|
||||
if design_data.color_id:
|
||||
color = db.query(Color).filter(Color.id == design_data.color_id).first()
|
||||
|
||||
|
||||
# 创建设计记录
|
||||
design = Design(
|
||||
user_id=user_id,
|
||||
@@ -52,8 +68,109 @@ def create_design(db: Session, user_id: int, design_data: DesignCreate) -> Desig
|
||||
)
|
||||
db.add(design)
|
||||
db.flush() # 获取 ID
|
||||
|
||||
# 生成图片
|
||||
|
||||
# 尝试 AI 生图
|
||||
if _has_ai_key():
|
||||
try:
|
||||
await _generate_ai_images(db, design, category, sub_type, color, design_data)
|
||||
db.commit()
|
||||
db.refresh(design)
|
||||
return design
|
||||
except Exception as e:
|
||||
logger.error(f"AI 生图全部失败,降级到 mock: {e}")
|
||||
db.rollback()
|
||||
# 重新查询,因为 rollback 后 ORM 对象可能失效
|
||||
design = db.query(Design).filter(Design.id == design.id).first()
|
||||
if not design:
|
||||
# rollback 导致 design 也没了,重新创建
|
||||
design = Design(
|
||||
user_id=user_id,
|
||||
category_id=design_data.category_id,
|
||||
sub_type_id=design_data.sub_type_id,
|
||||
color_id=design_data.color_id,
|
||||
prompt=design_data.prompt,
|
||||
carving_technique=design_data.carving_technique,
|
||||
design_style=design_data.design_style,
|
||||
motif=design_data.motif,
|
||||
size_spec=design_data.size_spec,
|
||||
surface_finish=design_data.surface_finish,
|
||||
usage_scene=design_data.usage_scene,
|
||||
status="generating"
|
||||
)
|
||||
db.add(design)
|
||||
db.flush()
|
||||
|
||||
# 降级到 mock 生成
|
||||
_generate_mock_fallback(db, design, category, sub_type, color, design_data)
|
||||
db.commit()
|
||||
db.refresh(design)
|
||||
return design
|
||||
|
||||
|
||||
async def _generate_ai_images(
|
||||
db: Session,
|
||||
design: Design,
|
||||
category,
|
||||
sub_type,
|
||||
color,
|
||||
design_data: DesignCreate,
|
||||
) -> None:
|
||||
"""使用 AI 模型为每个视角生成图片"""
|
||||
views = get_views_for_category(category.name)
|
||||
model = settings.AI_IMAGE_MODEL
|
||||
|
||||
for idx, view_name in enumerate(views):
|
||||
# 构建 prompt
|
||||
prompt_text = build_prompt(
|
||||
category_name=category.name,
|
||||
view_name=view_name,
|
||||
sub_type_name=sub_type.name if sub_type else None,
|
||||
color_name=color.name if color else None,
|
||||
user_prompt=design_data.prompt,
|
||||
carving_technique=design_data.carving_technique,
|
||||
design_style=design_data.design_style,
|
||||
motif=design_data.motif,
|
||||
size_spec=design_data.size_spec,
|
||||
surface_finish=design_data.surface_finish,
|
||||
usage_scene=design_data.usage_scene,
|
||||
)
|
||||
|
||||
# 调用 AI 生图
|
||||
remote_url = await ai_generator.generate_image(prompt_text, model)
|
||||
|
||||
# 下载保存到本地
|
||||
save_path = os.path.join(
|
||||
settings.UPLOAD_DIR, "designs", f"{design.id}_{view_name}.png"
|
||||
)
|
||||
local_url = await ai_generator.download_and_save(remote_url, save_path)
|
||||
|
||||
# 创建 DesignImage 记录
|
||||
design_image = DesignImage(
|
||||
design_id=design.id,
|
||||
view_name=view_name,
|
||||
image_url=local_url,
|
||||
model_used=model,
|
||||
prompt_used=prompt_text,
|
||||
sort_order=idx,
|
||||
)
|
||||
db.add(design_image)
|
||||
|
||||
# 第一张图(效果图)存入 design.image_url 兼容旧逻辑
|
||||
if idx == 0:
|
||||
design.image_url = local_url
|
||||
|
||||
design.status = "completed"
|
||||
|
||||
|
||||
def _generate_mock_fallback(
|
||||
db: Session,
|
||||
design: Design,
|
||||
category,
|
||||
sub_type,
|
||||
color,
|
||||
design_data: DesignCreate,
|
||||
) -> None:
|
||||
"""降级使用 mock 生成器"""
|
||||
save_path = os.path.join(settings.UPLOAD_DIR, "designs", f"{design.id}.png")
|
||||
image_url = generate_mock_design(
|
||||
category_name=category.name,
|
||||
@@ -68,13 +185,47 @@ def create_design(db: Session, user_id: int, design_data: DesignCreate) -> Desig
|
||||
surface_finish=design_data.surface_finish,
|
||||
usage_scene=design_data.usage_scene,
|
||||
)
|
||||
|
||||
# 更新设计记录
|
||||
design.image_url = image_url
|
||||
design.status = "completed"
|
||||
logger.info(f"Mock 降级生成完成: design_id={design.id}")
|
||||
|
||||
|
||||
def create_design(db: Session, user_id: int, design_data: DesignCreate) -> Design:
|
||||
"""
|
||||
同步版本创建设计(兼容旧调用,仅用 mock)
|
||||
"""
|
||||
category = db.query(Category).filter(Category.id == design_data.category_id).first()
|
||||
if not category:
|
||||
raise ValueError(f"品类不存在: {design_data.category_id}")
|
||||
|
||||
sub_type = None
|
||||
if design_data.sub_type_id:
|
||||
sub_type = db.query(SubType).filter(SubType.id == design_data.sub_type_id).first()
|
||||
|
||||
color = None
|
||||
if design_data.color_id:
|
||||
color = db.query(Color).filter(Color.id == design_data.color_id).first()
|
||||
|
||||
design = Design(
|
||||
user_id=user_id,
|
||||
category_id=design_data.category_id,
|
||||
sub_type_id=design_data.sub_type_id,
|
||||
color_id=design_data.color_id,
|
||||
prompt=design_data.prompt,
|
||||
carving_technique=design_data.carving_technique,
|
||||
design_style=design_data.design_style,
|
||||
motif=design_data.motif,
|
||||
size_spec=design_data.size_spec,
|
||||
surface_finish=design_data.surface_finish,
|
||||
usage_scene=design_data.usage_scene,
|
||||
status="generating"
|
||||
)
|
||||
db.add(design)
|
||||
db.flush()
|
||||
|
||||
_generate_mock_fallback(db, design, category, sub_type, color, design_data)
|
||||
db.commit()
|
||||
db.refresh(design)
|
||||
|
||||
return design
|
||||
|
||||
|
||||
@@ -132,16 +283,24 @@ def delete_design(db: Session, design_id: int, user_id: int) -> bool:
|
||||
if not design:
|
||||
return False
|
||||
|
||||
# 删除图片文件
|
||||
# 删除主图片文件
|
||||
if design.image_url:
|
||||
# image_url 格式: /uploads/designs/1001.png
|
||||
# 转换为实际文件路径
|
||||
file_path = design.image_url.lstrip("/")
|
||||
if os.path.exists(file_path):
|
||||
try:
|
||||
os.remove(file_path)
|
||||
except Exception:
|
||||
pass # 忽略删除失败
|
||||
pass
|
||||
|
||||
# 删除多视角图片文件
|
||||
for img in design.images:
|
||||
if img.image_url:
|
||||
fp = img.image_url.lstrip("/")
|
||||
if os.path.exists(fp):
|
||||
try:
|
||||
os.remove(fp)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 删除数据库记录
|
||||
db.delete(design)
|
||||
|
||||
164
backend/app/services/prompt_builder.py
Normal file
164
backend/app/services/prompt_builder.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""
|
||||
专业玉雕设计提示词构建器(数据库版)
|
||||
从数据库 prompt_templates + prompt_mappings 读取配置,支持后台热更新
|
||||
"""
|
||||
import logging
|
||||
from typing import Optional, Dict, List
|
||||
|
||||
from ..database import SessionLocal
|
||||
from ..models.prompt_template import PromptTemplate, PromptMapping
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ============================================================
|
||||
# 品类视角配置(保留硬编码,因为与业务流程强关联)
|
||||
# ============================================================
|
||||
CATEGORY_VIEWS: Dict[str, List[str]] = {
|
||||
"牌子": ["效果图", "正面图", "背面图"],
|
||||
"珠子": ["效果图", "正面图"],
|
||||
"手把件": ["效果图", "正面图", "侧面图", "背面图"],
|
||||
"雕刻件": ["效果图", "正面图", "侧面图", "背面图"],
|
||||
"摆件": ["效果图", "正面图", "侧面图", "背面图"],
|
||||
"手镯": ["效果图", "正面图", "侧面图"],
|
||||
"耳钉": ["效果图", "正面图"],
|
||||
"耳饰": ["效果图", "正面图"],
|
||||
"手链": ["效果图", "正面图"],
|
||||
"项链": ["效果图", "正面图"],
|
||||
"戒指": ["效果图", "正面图", "侧面图"],
|
||||
"表带": ["效果图", "正面图"],
|
||||
}
|
||||
|
||||
|
||||
def _load_mappings(mapping_type: str) -> Dict[str, str]:
|
||||
"""从数据库加载指定类型的映射字典"""
|
||||
try:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
rows = db.query(PromptMapping).filter(
|
||||
PromptMapping.mapping_type == mapping_type
|
||||
).order_by(PromptMapping.sort_order).all()
|
||||
return {r.cn_key: r.en_value for r in rows}
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"加载映射 {mapping_type} 失败: {e}")
|
||||
return {}
|
||||
|
||||
|
||||
def _load_template(template_key: str, default: str = "") -> str:
|
||||
"""从数据库加载模板"""
|
||||
try:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
tpl = db.query(PromptTemplate).filter(
|
||||
PromptTemplate.template_key == template_key
|
||||
).first()
|
||||
if tpl:
|
||||
return tpl.template_value
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"加载模板 {template_key} 失败: {e}")
|
||||
return default
|
||||
|
||||
|
||||
def get_views_for_category(category_name: str) -> List[str]:
|
||||
"""获取品类对应的视角列表"""
|
||||
return CATEGORY_VIEWS.get(category_name, ["效果图", "正面图"])
|
||||
|
||||
|
||||
def build_prompt(
|
||||
category_name: str,
|
||||
view_name: str,
|
||||
sub_type_name: Optional[str] = None,
|
||||
color_name: Optional[str] = None,
|
||||
user_prompt: Optional[str] = None,
|
||||
carving_technique: Optional[str] = None,
|
||||
design_style: Optional[str] = None,
|
||||
motif: Optional[str] = None,
|
||||
size_spec: Optional[str] = None,
|
||||
surface_finish: Optional[str] = None,
|
||||
usage_scene: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
构建专业英文生图提示词(从数据库读取映射和模板)
|
||||
|
||||
业务逻辑:用户参数 → 中英映射 → 填入模板 → 最终prompt
|
||||
"""
|
||||
# 从数据库加载所有映射
|
||||
category_map = _load_mappings("category")
|
||||
color_map = _load_mappings("color")
|
||||
view_map = _load_mappings("view")
|
||||
carving_map = _load_mappings("carving")
|
||||
style_map = _load_mappings("style")
|
||||
motif_map = _load_mappings("motif")
|
||||
finish_map = _load_mappings("finish")
|
||||
scene_map = _load_mappings("scene")
|
||||
sub_type_map = _load_mappings("sub_type")
|
||||
|
||||
# 加载模板
|
||||
quality_suffix = _load_template("quality_suffix",
|
||||
"professional jewelry product photography, studio lighting setup, pure white background, ultra-detailed, sharp focus, 8K resolution, photorealistic rendering, high-end commercial quality")
|
||||
default_color = _load_template("default_color",
|
||||
"natural Hetian nephrite jade with warm luster")
|
||||
|
||||
# 构建各部分
|
||||
parts = []
|
||||
|
||||
# 1. 品类主体
|
||||
subject = category_map.get(category_name, f"Chinese Hetian nephrite jade {category_name}")
|
||||
parts.append(subject)
|
||||
|
||||
# 2. 子类型
|
||||
if sub_type_name:
|
||||
sub_detail = sub_type_map.get(sub_type_name, sub_type_name)
|
||||
parts.append(sub_detail)
|
||||
|
||||
# 3. 颜色
|
||||
if color_name:
|
||||
color_desc = color_map.get(color_name, f"{color_name} colored nephrite jade")
|
||||
parts.append(color_desc)
|
||||
else:
|
||||
parts.append(default_color)
|
||||
|
||||
# 4. 题材纹样
|
||||
if motif:
|
||||
motif_desc = motif_map.get(motif, f"{motif} themed design")
|
||||
parts.append(f"featuring {motif_desc}")
|
||||
|
||||
# 5. 雕刻工艺
|
||||
if carving_technique:
|
||||
tech_desc = carving_map.get(carving_technique, carving_technique)
|
||||
parts.append(tech_desc)
|
||||
|
||||
# 6. 设计风格
|
||||
if design_style:
|
||||
style_desc = style_map.get(design_style, design_style)
|
||||
parts.append(style_desc)
|
||||
|
||||
# 7. 表面处理
|
||||
if surface_finish:
|
||||
finish_desc = finish_map.get(surface_finish, surface_finish)
|
||||
parts.append(finish_desc)
|
||||
|
||||
# 8. 用途场景
|
||||
if usage_scene:
|
||||
scene_desc = scene_map.get(usage_scene, usage_scene)
|
||||
parts.append(scene_desc)
|
||||
|
||||
# 9. 尺寸
|
||||
if size_spec:
|
||||
parts.append(f"size approximately {size_spec}")
|
||||
|
||||
# 10. 用户描述
|
||||
if user_prompt:
|
||||
parts.append(f"design concept: {user_prompt}")
|
||||
|
||||
# 11. 视角
|
||||
view_desc = view_map.get(view_name, "three-quarter view")
|
||||
parts.append(view_desc)
|
||||
|
||||
# 12. 质量后缀
|
||||
parts.append(quality_suffix)
|
||||
|
||||
return ", ".join(parts)
|
||||
@@ -56,3 +56,19 @@ def get_current_user(
|
||||
raise credentials_exception
|
||||
|
||||
return user
|
||||
|
||||
|
||||
def get_admin_user(
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> User:
|
||||
"""
|
||||
获取当前管理员用户
|
||||
|
||||
验证当前用户是否为管理员,非管理员抛出 403
|
||||
"""
|
||||
if not current_user.is_admin:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="权限不足,需要管理员权限"
|
||||
)
|
||||
return current_user
|
||||
|
||||
Reference in New Issue
Block a user