""" 设计服务 处理设计相关的业务逻辑,支持 AI 多视角生图 + mock 降级 """ import os import logging from typing import List, Optional, Tuple from sqlalchemy.orm import Session from sqlalchemy import desc from ..models import Design, DesignImage, Category, SubType, Color from ..schemas import DesignCreate from ..config import settings from .mock_generator import generate_mock_design from .prompt_builder import get_views_for_category, build_prompt from . import ai_generator logger = logging.getLogger(__name__) def _has_ai_key() -> bool: """检查是否配置了 AI API Key(从数据库配置优先读取)""" from .config_service import get_ai_config ai_config = get_ai_config() model = ai_config.get("AI_IMAGE_MODEL", "flux-dev") if model in ("seedream-5.0", "seedream-4.5"): return bool(ai_config.get("VOLCENGINE_API_KEY")) return bool(ai_config.get("SILICONFLOW_API_KEY")) async def create_design_async(db: Session, user_id: int, design_data: DesignCreate) -> Design: """ 创建设计记录(异步版本,支持 AI 多视角生图) 流程: 1. 创建 Design 记录(status=generating) 2. 获取品类视角列表 3. 循环每个视角:构建 prompt → 调用 AI 生图 → 下载保存 → 创建 DesignImage 4. 第一张效果图 URL 存入 design.image_url(兼容旧逻辑) 5. 更新 status=completed 6. 失败时降级到 mock_generator """ # 获取关联信息 category = db.query(Category).filter(Category.id == design_data.category_id).first() if not category: raise ValueError(f"品类不存在: {design_data.category_id}") sub_type = None if design_data.sub_type_id: sub_type = db.query(SubType).filter(SubType.id == design_data.sub_type_id).first() color = None if design_data.color_id: color = db.query(Color).filter(Color.id == design_data.color_id).first() # 创建设计记录 design = Design( user_id=user_id, category_id=design_data.category_id, sub_type_id=design_data.sub_type_id, color_id=design_data.color_id, prompt=design_data.prompt, carving_technique=design_data.carving_technique, design_style=design_data.design_style, motif=design_data.motif, size_spec=design_data.size_spec, surface_finish=design_data.surface_finish, usage_scene=design_data.usage_scene, status="generating" ) db.add(design) db.flush() # 获取 ID # 尝试 AI 生图 if _has_ai_key(): try: await _generate_ai_images(db, design, category, sub_type, color, design_data) db.commit() db.refresh(design) return design except Exception as e: logger.error(f"AI 生图全部失败,降级到 mock: {e}") db.rollback() # 重新查询,因为 rollback 后 ORM 对象可能失效 design = db.query(Design).filter(Design.id == design.id).first() if not design: # rollback 导致 design 也没了,重新创建 design = Design( user_id=user_id, category_id=design_data.category_id, sub_type_id=design_data.sub_type_id, color_id=design_data.color_id, prompt=design_data.prompt, carving_technique=design_data.carving_technique, design_style=design_data.design_style, motif=design_data.motif, size_spec=design_data.size_spec, surface_finish=design_data.surface_finish, usage_scene=design_data.usage_scene, status="generating" ) db.add(design) db.flush() # 降级到 mock 生成 _generate_mock_fallback(db, design, category, sub_type, color, design_data) db.commit() db.refresh(design) return design async def _generate_ai_images( db: Session, design: Design, category, sub_type, color, design_data: DesignCreate, ) -> None: """使用 AI 模型为每个视角生成图片 多视角一致性策略: - SiliconFlow Kolors: 通过复用 seed 保持一致 - Seedream 5.0 lite: 通过参考图(image参数)保持一致 """ views = get_views_for_category(category.name) from .config_service import get_ai_config ai_config = get_ai_config() model = ai_config.get("AI_IMAGE_MODEL", "flux-dev") shared_seed = None # Kolors 用: 第一张图的 seed first_remote_url = None # Seedream 用: 第一张图的远程 URL 作为参考图 for idx, view_name in enumerate(views): # 构建 prompt prompt_text = build_prompt( category_name=category.name, view_name=view_name, sub_type_name=sub_type.name if sub_type else None, color_name=color.name if color else None, user_prompt=design_data.prompt, carving_technique=design_data.carving_technique, design_style=design_data.design_style, motif=design_data.motif, size_spec=design_data.size_spec, surface_finish=design_data.surface_finish, usage_scene=design_data.usage_scene, ) # 调用 AI 生图 # 后续视角传入 seed(Kolors)或参考图 URL(Seedream)保持一致性 ref_url = first_remote_url if idx > 0 else None remote_url, returned_seed = await ai_generator.generate_image( prompt_text, model, seed=shared_seed, ref_image_url=ref_url ) # 第一张图保存信息供后续视角复用 if idx == 0: first_remote_url = remote_url if returned_seed is not None: shared_seed = returned_seed logger.info(f"多视角生图: seed={shared_seed}, ref_url={remote_url[:60]}...") # 直接使用远程 URL,不下载到本地(节省服务器存储空间) image_url = remote_url # 创建 DesignImage 记录 design_image = DesignImage( design_id=design.id, view_name=view_name, image_url=image_url, model_used=model, prompt_used=prompt_text, sort_order=idx, ) db.add(design_image) # 第一张图(效果图)存入 design.image_url 兼容旧逻辑 if idx == 0: design.image_url = image_url design.status = "completed" def _generate_mock_fallback( db: Session, design: Design, category, sub_type, color, design_data: DesignCreate, ) -> None: """降级使用 mock 生成器""" save_path = os.path.join(settings.UPLOAD_DIR, "designs", f"{design.id}.png") image_url = generate_mock_design( category_name=category.name, sub_type_name=sub_type.name if sub_type else None, color_name=color.name if color else None, prompt=design_data.prompt, save_path=save_path, carving_technique=design_data.carving_technique, design_style=design_data.design_style, motif=design_data.motif, size_spec=design_data.size_spec, surface_finish=design_data.surface_finish, usage_scene=design_data.usage_scene, ) design.image_url = image_url design.status = "completed" logger.info(f"Mock 降级生成完成: design_id={design.id}") def create_design(db: Session, user_id: int, design_data: DesignCreate) -> Design: """ 同步版本创建设计(兼容旧调用,仅用 mock) """ category = db.query(Category).filter(Category.id == design_data.category_id).first() if not category: raise ValueError(f"品类不存在: {design_data.category_id}") sub_type = None if design_data.sub_type_id: sub_type = db.query(SubType).filter(SubType.id == design_data.sub_type_id).first() color = None if design_data.color_id: color = db.query(Color).filter(Color.id == design_data.color_id).first() design = Design( user_id=user_id, category_id=design_data.category_id, sub_type_id=design_data.sub_type_id, color_id=design_data.color_id, prompt=design_data.prompt, carving_technique=design_data.carving_technique, design_style=design_data.design_style, motif=design_data.motif, size_spec=design_data.size_spec, surface_finish=design_data.surface_finish, usage_scene=design_data.usage_scene, status="generating" ) db.add(design) db.flush() _generate_mock_fallback(db, design, category, sub_type, color, design_data) db.commit() db.refresh(design) return design def get_user_designs( db: Session, user_id: int, page: int = 1, page_size: int = 20 ) -> Tuple[List[Design], int]: """ 分页查询用户设计历史 Returns: (设计列表, 总数) """ query = db.query(Design).filter(Design.user_id == user_id) # 获取总数 total = query.count() # 分页查询,按创建时间倒序 offset = (page - 1) * page_size designs = query.order_by(desc(Design.created_at)).offset(offset).limit(page_size).all() return designs, total def get_design_by_id(db: Session, design_id: int, user_id: int) -> Optional[Design]: """ 获取单个设计 只返回属于该用户的设计 """ return db.query(Design).filter( Design.id == design_id, Design.user_id == user_id ).first() def delete_design(db: Session, design_id: int, user_id: int) -> bool: """ 删除设计 1. 查找设计(必须属于该用户) 2. 删除图片文件 3. 删除数据库记录 Returns: 是否删除成功 """ design = db.query(Design).filter( Design.id == design_id, Design.user_id == user_id ).first() if not design: return False # 删除主图片文件 if design.image_url: file_path = design.image_url.lstrip("/") if os.path.exists(file_path): try: os.remove(file_path) except Exception: pass # 删除多视角图片文件 for img in design.images: if img.image_url: fp = img.image_url.lstrip("/") if os.path.exists(fp): try: os.remove(fp) except Exception: pass # 删除数据库记录 db.delete(design) db.commit() return True