Files
YuShiSheJiShi/backend/app/services/design_service.py
bb84747917 feat(ai): 升级AI生图模型及多视角一致性支持
- 将默认AI生图模型升级为flux-dev及seedream-5.0版本
- SiliconFlow模型由FLUX.1-dev切换为Kolors,优化调用参数和返回值
- 火山引擎Seedream升级至5.0 lite版本,支持多视角参考图传入
- 设计图片字段由字符串改为Text扩展URL长度限制
- 设计图下载支持远程URL重定向和本地文件兼容
- 生成AI图片时多视角保持风格一致,SiliconFlow复用seed,Seedream传参考图
- 后台配置界面更改模型名称及价格显示,新增API Key状态检测
- 前端照片下载从链接改为按钮,远程文件新窗口打开
- 设计相关接口支持较长请求超时,下载走API路径无/api前缀
- 前端页面兼容驼峰与下划线格式URL参数识别
- 用户中心设计图下载支持本地文件Token授权下载
- 初始化数据库新增完整表结构与约束,适配新版设计业务逻辑
2026-03-27 17:39:01 +08:00

330 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
设计服务
处理设计相关的业务逻辑,支持 AI 多视角生图 + mock 降级
"""
import os
import logging
from typing import List, Optional, Tuple
from sqlalchemy.orm import Session
from sqlalchemy import desc
from ..models import Design, DesignImage, Category, SubType, Color
from ..schemas import DesignCreate
from ..config import settings
from .mock_generator import generate_mock_design
from .prompt_builder import get_views_for_category, build_prompt
from . import ai_generator
logger = logging.getLogger(__name__)
def _has_ai_key() -> bool:
"""检查是否配置了 AI API Key从数据库配置优先读取"""
from .config_service import get_ai_config
ai_config = get_ai_config()
model = ai_config.get("AI_IMAGE_MODEL", "flux-dev")
if model in ("seedream-5.0", "seedream-4.5"):
return bool(ai_config.get("VOLCENGINE_API_KEY"))
return bool(ai_config.get("SILICONFLOW_API_KEY"))
async def create_design_async(db: Session, user_id: int, design_data: DesignCreate) -> Design:
"""
创建设计记录(异步版本,支持 AI 多视角生图)
流程:
1. 创建 Design 记录status=generating
2. 获取品类视角列表
3. 循环每个视角:构建 prompt → 调用 AI 生图 → 下载保存 → 创建 DesignImage
4. 第一张效果图 URL 存入 design.image_url兼容旧逻辑
5. 更新 status=completed
6. 失败时降级到 mock_generator
"""
# 获取关联信息
category = db.query(Category).filter(Category.id == design_data.category_id).first()
if not category:
raise ValueError(f"品类不存在: {design_data.category_id}")
sub_type = None
if design_data.sub_type_id:
sub_type = db.query(SubType).filter(SubType.id == design_data.sub_type_id).first()
color = None
if design_data.color_id:
color = db.query(Color).filter(Color.id == design_data.color_id).first()
# 创建设计记录
design = Design(
user_id=user_id,
category_id=design_data.category_id,
sub_type_id=design_data.sub_type_id,
color_id=design_data.color_id,
prompt=design_data.prompt,
carving_technique=design_data.carving_technique,
design_style=design_data.design_style,
motif=design_data.motif,
size_spec=design_data.size_spec,
surface_finish=design_data.surface_finish,
usage_scene=design_data.usage_scene,
status="generating"
)
db.add(design)
db.flush() # 获取 ID
# 尝试 AI 生图
if _has_ai_key():
try:
await _generate_ai_images(db, design, category, sub_type, color, design_data)
db.commit()
db.refresh(design)
return design
except Exception as e:
logger.error(f"AI 生图全部失败,降级到 mock: {e}")
db.rollback()
# 重新查询,因为 rollback 后 ORM 对象可能失效
design = db.query(Design).filter(Design.id == design.id).first()
if not design:
# rollback 导致 design 也没了,重新创建
design = Design(
user_id=user_id,
category_id=design_data.category_id,
sub_type_id=design_data.sub_type_id,
color_id=design_data.color_id,
prompt=design_data.prompt,
carving_technique=design_data.carving_technique,
design_style=design_data.design_style,
motif=design_data.motif,
size_spec=design_data.size_spec,
surface_finish=design_data.surface_finish,
usage_scene=design_data.usage_scene,
status="generating"
)
db.add(design)
db.flush()
# 降级到 mock 生成
_generate_mock_fallback(db, design, category, sub_type, color, design_data)
db.commit()
db.refresh(design)
return design
async def _generate_ai_images(
db: Session,
design: Design,
category,
sub_type,
color,
design_data: DesignCreate,
) -> None:
"""使用 AI 模型为每个视角生成图片
多视角一致性策略:
- SiliconFlow Kolors: 通过复用 seed 保持一致
- Seedream 5.0 lite: 通过参考图image参数保持一致
"""
views = get_views_for_category(category.name)
from .config_service import get_ai_config
ai_config = get_ai_config()
model = ai_config.get("AI_IMAGE_MODEL", "flux-dev")
shared_seed = None # Kolors 用: 第一张图的 seed
first_remote_url = None # Seedream 用: 第一张图的远程 URL 作为参考图
for idx, view_name in enumerate(views):
# 构建 prompt
prompt_text = build_prompt(
category_name=category.name,
view_name=view_name,
sub_type_name=sub_type.name if sub_type else None,
color_name=color.name if color else None,
user_prompt=design_data.prompt,
carving_technique=design_data.carving_technique,
design_style=design_data.design_style,
motif=design_data.motif,
size_spec=design_data.size_spec,
surface_finish=design_data.surface_finish,
usage_scene=design_data.usage_scene,
)
# 调用 AI 生图
# 后续视角传入 seedKolors或参考图 URLSeedream保持一致性
ref_url = first_remote_url if idx > 0 else None
remote_url, returned_seed = await ai_generator.generate_image(
prompt_text, model, seed=shared_seed, ref_image_url=ref_url
)
# 第一张图保存信息供后续视角复用
if idx == 0:
first_remote_url = remote_url
if returned_seed is not None:
shared_seed = returned_seed
logger.info(f"多视角生图: seed={shared_seed}, ref_url={remote_url[:60]}...")
# 直接使用远程 URL不下载到本地节省服务器存储空间
image_url = remote_url
# 创建 DesignImage 记录
design_image = DesignImage(
design_id=design.id,
view_name=view_name,
image_url=image_url,
model_used=model,
prompt_used=prompt_text,
sort_order=idx,
)
db.add(design_image)
# 第一张图(效果图)存入 design.image_url 兼容旧逻辑
if idx == 0:
design.image_url = image_url
design.status = "completed"
def _generate_mock_fallback(
db: Session,
design: Design,
category,
sub_type,
color,
design_data: DesignCreate,
) -> None:
"""降级使用 mock 生成器"""
save_path = os.path.join(settings.UPLOAD_DIR, "designs", f"{design.id}.png")
image_url = generate_mock_design(
category_name=category.name,
sub_type_name=sub_type.name if sub_type else None,
color_name=color.name if color else None,
prompt=design_data.prompt,
save_path=save_path,
carving_technique=design_data.carving_technique,
design_style=design_data.design_style,
motif=design_data.motif,
size_spec=design_data.size_spec,
surface_finish=design_data.surface_finish,
usage_scene=design_data.usage_scene,
)
design.image_url = image_url
design.status = "completed"
logger.info(f"Mock 降级生成完成: design_id={design.id}")
def create_design(db: Session, user_id: int, design_data: DesignCreate) -> Design:
"""
同步版本创建设计(兼容旧调用,仅用 mock
"""
category = db.query(Category).filter(Category.id == design_data.category_id).first()
if not category:
raise ValueError(f"品类不存在: {design_data.category_id}")
sub_type = None
if design_data.sub_type_id:
sub_type = db.query(SubType).filter(SubType.id == design_data.sub_type_id).first()
color = None
if design_data.color_id:
color = db.query(Color).filter(Color.id == design_data.color_id).first()
design = Design(
user_id=user_id,
category_id=design_data.category_id,
sub_type_id=design_data.sub_type_id,
color_id=design_data.color_id,
prompt=design_data.prompt,
carving_technique=design_data.carving_technique,
design_style=design_data.design_style,
motif=design_data.motif,
size_spec=design_data.size_spec,
surface_finish=design_data.surface_finish,
usage_scene=design_data.usage_scene,
status="generating"
)
db.add(design)
db.flush()
_generate_mock_fallback(db, design, category, sub_type, color, design_data)
db.commit()
db.refresh(design)
return design
def get_user_designs(
db: Session,
user_id: int,
page: int = 1,
page_size: int = 20
) -> Tuple[List[Design], int]:
"""
分页查询用户设计历史
Returns:
(设计列表, 总数)
"""
query = db.query(Design).filter(Design.user_id == user_id)
# 获取总数
total = query.count()
# 分页查询,按创建时间倒序
offset = (page - 1) * page_size
designs = query.order_by(desc(Design.created_at)).offset(offset).limit(page_size).all()
return designs, total
def get_design_by_id(db: Session, design_id: int, user_id: int) -> Optional[Design]:
"""
获取单个设计
只返回属于该用户的设计
"""
return db.query(Design).filter(
Design.id == design_id,
Design.user_id == user_id
).first()
def delete_design(db: Session, design_id: int, user_id: int) -> bool:
"""
删除设计
1. 查找设计(必须属于该用户)
2. 删除图片文件
3. 删除数据库记录
Returns:
是否删除成功
"""
design = db.query(Design).filter(
Design.id == design_id,
Design.user_id == user_id
).first()
if not design:
return False
# 删除主图片文件
if design.image_url:
file_path = design.image_url.lstrip("/")
if os.path.exists(file_path):
try:
os.remove(file_path)
except Exception:
pass
# 删除多视角图片文件
for img in design.images:
if img.image_url:
fp = img.image_url.lstrip("/")
if os.path.exists(fp):
try:
os.remove(fp)
except Exception:
pass
# 删除数据库记录
db.delete(design)
db.commit()
return True