527 lines
18 KiB
Python
527 lines
18 KiB
Python
"""共享API Hub路由"""
|
||
from fastapi import APIRouter, Depends, HTTPException, Header, Request
|
||
from sqlalchemy.orm import Session
|
||
from sqlalchemy import func as sa_func
|
||
from pydantic import BaseModel
|
||
from typing import Optional
|
||
from datetime import datetime, timedelta
|
||
import time
|
||
import hashlib
|
||
|
||
from database import get_db
|
||
from config import SECRET_KEY
|
||
from models.user import User
|
||
from models.system_config import SystemConfig
|
||
from models.shared_api import SharedApiCategory, SharedApi, SharedApiLog
|
||
from routers.auth import get_current_user, get_admin_user
|
||
|
||
# Router holding every shared-API-Hub endpoint below (any URL prefix is
# presumably added where this router is included into the app).
router: APIRouter = APIRouter()
|
||
|
||
# ========== 加密工具 ==========
|
||
|
||
# Cached Fernet cipher; stays None until _get_fernet() builds it on first use.
_fernet: Optional["Fernet"] = None
|
||
|
||
def _get_fernet():
    """Return the module-wide Fernet cipher, creating it on first call.

    The Fernet key is the SHA-256 digest of ``SECRET_KEY``, url-safe base64
    encoded as Fernet requires. Changing SECRET_KEY therefore makes every
    previously stored ciphertext undecryptable.
    """
    global _fernet
    if _fernet is not None:
        return _fernet

    import base64
    from cryptography.fernet import Fernet

    # Derive a 32-byte key from SECRET_KEY so no separate secret is needed.
    raw_key = hashlib.sha256(SECRET_KEY.encode()).digest()
    _fernet = Fernet(base64.urlsafe_b64encode(raw_key))
    return _fernet
|
||
|
||
def encrypt_key(plain: str) -> str:
    """Encrypt a plaintext API key with the module Fernet cipher.

    Returns the Fernet token as a str, or "" when *plain* is empty/None.
    """
    if plain:
        token_bytes = _get_fernet().encrypt(plain.encode())
        return token_bytes.decode()
    return ""
|
||
|
||
def decrypt_key(encrypted: str) -> str:
    """Decrypt a value produced by encrypt_key().

    Returns "" when *encrypted* is empty or when the token is invalid: it was
    tampered with, is malformed, or was encrypted under a different
    SECRET_KEY. Other errors are no longer swallowed, for example a missing
    ``cryptography`` package or an unusable SECRET_KEY. Such
    misconfiguration must not masquerade as "this API has no key".
    """
    if not encrypted:
        return ""
    from cryptography.fernet import InvalidToken

    try:
        return _get_fernet().decrypt(encrypted.encode()).decode()
    except InvalidToken:
        # Deliberate best-effort: an undecryptable stored key reads as "no key".
        return ""
|
||
|
||
def mask_key(encrypted: str) -> str:
    """Return a masked (display-safe) form of an encrypted API key.

    The key is decrypted with decrypt_key(). "" is returned when there is no
    key or it cannot be decrypted. At most half of the key is ever revealed:

    * 4 characters or fewer  -> "***" (nothing revealed)
    * fewer than 16          -> first 2 characters + "***"
    * 16 or more             -> first 4 + "****" + last 4
    """
    plain = decrypt_key(encrypted)
    if not plain:
        return ""
    length = len(plain)
    if length <= 4:
        # The previous plain[:2] form revealed a 1-2 character key completely.
        return "***"
    if length < 16:
        # The previous 4+4 form exposed 8 of only 9-15 characters here.
        return plain[:2] + "***"
    return plain[:4] + "****" + plain[-4:]
|
||
|
||
|
||
# ========== Hub访问密码机制 ==========
|
||
|
||
def _get_hub_password(db: Session) -> str:
    """Return the stored hub-password digest from SystemConfig, or "" if unset.

    The value is whatever set_hub_password() wrote under the
    "api_hub_password" key (a SHA-256 hex digest, not the plaintext).
    """
    record = (
        db.query(SystemConfig)
        .filter(SystemConfig.key == "api_hub_password")
        .first()
    )
    if record:
        return record.value
    return ""
|
||
|
||
def _hub_password_version(db: Session) -> str:
    """Return a short fingerprint of the current hub password, or "none" if unset.

    The fingerprint is stored in every hub token ("pv" claim). When the
    password changes, the fingerprint changes and old tokens are rejected.

    It is an HMAC of the stored digest keyed with SECRET_KEY. The JWT payload
    is only base64-encoded, so the previous value was readable by anyone
    holding a token. That value was the first 8 hex characters of the
    unsalted SHA-256 password hash, and it allowed offline brute force of
    short passwords. Deploying this change invalidates outstanding hub tokens
    once; those tokens live 2 hours.
    """
    import hmac

    digest = _get_hub_password(db)
    if not digest:
        return "none"
    mac = hmac.new(SECRET_KEY.encode(), digest.encode(), hashlib.sha256)
    return mac.hexdigest()[:16]
|
||
|
||
def _hash_password(pwd: str) -> str:
    """Return the hex SHA-256 digest of *pwd* encoded as UTF-8.

    NOTE(review): an unsalted fast hash is weak for passwords; switching to a
    salted KDF would require migrating the stored "api_hub_password" value.
    """
    hasher = hashlib.sha256()
    hasher.update(pwd.encode())
    return hasher.hexdigest()
|
||
|
||
def _create_hub_token(user_id: int, pwd_ver: str = "none") -> str:
    """Create an HS256 hub-access JWT for *user_id*, valid for 2 hours.

    Claims: ``sub`` (user id as str), ``hub`` (True), ``pv`` (password
    version from _hub_password_version) and ``exp``.
    """
    from datetime import timezone

    from jose import jwt

    # Aware UTC datetime: datetime.utcnow() is deprecated since Python 3.12.
    # python-jose converts either form to the same numeric timestamp.
    expires_at = datetime.now(timezone.utc) + timedelta(hours=2)
    claims = {"sub": str(user_id), "hub": True, "pv": pwd_ver, "exp": expires_at}
    return jwt.encode(claims, SECRET_KEY, algorithm="HS256")
|
||
|
||
def verify_hub_access(
    x_hub_token: Optional[str] = Header(None),
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """FastAPI dependency: require a logged-in user AND a valid hub token.

    The hub token is read from the ``x-hub-token`` request header. It must:

    * be a valid, unexpired HS256 JWT signed with SECRET_KEY;
    * carry the ``hub`` claim;
    * have been issued to *this* user (``sub``). Previously this was never
      checked, so any logged-in user could use someone else's hub token;
    * match the current password version (``pv``), so that changing the
      password revokes old tokens.

    Returns the current user. Raises HTTPException(403) on any failure.
    """
    if not x_hub_token:
        raise HTTPException(status_code=403, detail="需要API Hub访问权限,请先验证密码")
    from jose import jwt, JWTError

    try:
        payload = jwt.decode(x_hub_token, SECRET_KEY, algorithms=["HS256"])
    except JWTError:
        raise HTTPException(status_code=403, detail="Hub令牌已过期,请重新验证密码") from None

    # Reject tokens lacking the hub claim, e.g. other JWTs signed with the same
    # SECRET_KEY. Also reject hub tokens minted for a different user.
    if not payload.get("hub") or payload.get("sub") != str(current_user.id):
        raise HTTPException(status_code=403, detail="无效的Hub令牌")
    if payload.get("pv", "") != _hub_password_version(db):
        raise HTTPException(status_code=403, detail="密码已变更,请重新验证")
    return current_user
|
||
|
||
|
||
# ========== Schemas ==========
|
||
|
||
class HubAuthRequest(BaseModel):
    """Request body carrying a plaintext hub password.

    Used both to log in to the hub (/auth) and for an admin to set the
    password (/admin/password).
    """

    # Plaintext password; only its SHA-256 digest is stored or compared.
    password: str
|
||
|
||
class CategoryCreate(BaseModel):
    """Request body for creating an API category."""

    # Display name; create_category rejects names that already exist.
    name: str
    # Optional icon value, stored as given.
    icon: str = ""
|
||
|
||
class CategoryUpdate(BaseModel):
    """Partial update of a category; a field left as None is not changed."""

    name: Optional[str] = None  # new display name
    icon: Optional[str] = None  # new icon value
    sort_order: Optional[int] = None  # position used when listing categories
    is_active: Optional[bool] = None  # enable/disable flag
|
||
|
||
class ApiCreate(BaseModel):
    """Request body for registering a shared API."""

    category_id: Optional[int] = None  # None -> uncategorised
    name: str
    description: str = ""
    base_url: str = ""  # request paths are appended to this when testing
    doc_url: str = ""
    # How the stored key is injected when testing; the test endpoint
    # recognises "none", "bearer", "api_key" and "basic".
    auth_type: str = "none"
    # Sent in plaintext; the backend stores it encrypted.
    api_key: str = ""
    # Header that receives the key for "bearer" / "api_key" auth.
    api_key_header: str = "Authorization"
    health_check_url: str = ""  # falls back to base_url when empty
    tags: str = ""  # free-text, matched by keyword search
|
||
|
||
class ApiUpdate(BaseModel):
    """Partial update of a shared API; a field left as None is not changed.

    An empty-string ``api_key`` is also treated as "keep the stored key".
    With this schema, category_id and the key therefore cannot be cleared.
    """

    category_id: Optional[int] = None
    name: Optional[str] = None
    description: Optional[str] = None
    base_url: Optional[str] = None
    doc_url: Optional[str] = None
    auth_type: Optional[str] = None
    api_key: Optional[str] = None  # plaintext; re-encrypted when provided
    api_key_header: Optional[str] = None
    health_check_url: Optional[str] = None
    tags: Optional[str] = None
    is_active: Optional[bool] = None
|
||
|
||
class ApiTestRequest(BaseModel):
    """Request body for proxying a test call to a shared API."""

    # HTTP method, compared case-insensitively by the test endpoint.
    method: str = "GET"
    # Appended to the API's base_url (leading "/" optional).
    path: str = ""
    # Raw request body text; empty means no body.
    body: str = ""
    # Extra request headers. Pydantic copies field defaults per instance, so
    # this {} is not shared between requests.
    headers: dict = {}
|
||
|
||
|
||
# ========== 密码认证接口 ==========
|
||
|
||
@router.post("/auth")
def hub_auth(
    data: HubAuthRequest,
    db: Session = Depends(get_db),
    user: User = Depends(get_current_user),
):
    """Exchange the hub password for a hub access token.

    Requires a logged-in user.

    * Responds 400 if no hub password has been configured.
    * Responds 403 if the password is wrong.
    * Otherwise returns ``{"hub_token", "expires_in"}``. The token is bound to
      the caller and to the current password version, and lasts 7200 s.

    NOTE(review): there is no attempt throttling here, and set_hub_password
    allows 4-character passwords. Rate-limit this endpoint if nothing upstream
    does.
    """
    import hmac

    stored = _get_hub_password(db)
    if not stored:
        raise HTTPException(status_code=400, detail="管理员尚未设置访问密码")
    # Constant-time comparison so response timing does not reveal how much of
    # the digest matched. Compared as bytes because compare_digest rejects
    # non-ASCII str.
    candidate = _hash_password(data.password)
    if not hmac.compare_digest(candidate.encode(), stored.encode()):
        raise HTTPException(status_code=403, detail="密码错误")
    token = _create_hub_token(user.id, _hub_password_version(db))
    return {"hub_token": token, "expires_in": 7200}
|
||
|
||
@router.get("/check-password")
def check_password_set(
    db: Session = Depends(get_db),
    user: User = Depends(get_current_user),
):
    """Tell any logged-in user whether a hub password has been configured."""
    has_password = bool(_get_hub_password(db))
    return {"has_password": has_password}
|
||
|
||
@router.put("/admin/password")
def set_hub_password(
    data: HubAuthRequest,
    db: Session = Depends(get_db),
    admin: User = Depends(get_admin_user),
):
    """Set or replace the hub password (guarded by get_admin_user).

    Passwords shorter than 4 characters are rejected with 400. Only the
    SHA-256 digest is stored, in SystemConfig under "api_hub_password".
    Changing it alters _hub_password_version, which invalidates existing hub
    tokens.
    """
    if len(data.password) < 4:
        raise HTTPException(status_code=400, detail="密码至少4位")
    digest = _hash_password(data.password)

    record = db.query(SystemConfig).filter(SystemConfig.key == "api_hub_password").first()
    if not record:
        db.add(SystemConfig(key="api_hub_password", value=digest, description="API Hub访问密码"))
    else:
        record.value = digest
    db.commit()
    return {"message": "密码设置成功"}
|
||
|
||
@router.get("/admin/password-status")
def get_password_status(
    db: Session = Depends(get_db),
    admin: User = Depends(get_admin_user),
):
    """Admin view of whether a hub password is set (guarded by get_admin_user)."""
    configured = _get_hub_password(db)
    return {"has_password": bool(configured)}
|
||
|
||
|
||
# ========== 分类管理 ==========
|
||
|
||
@router.get("/categories")
def list_categories(
    db: Session = Depends(get_db),
    user: User = Depends(verify_hub_access),
):
    """List all categories ordered by (sort_order, id).

    ``api_count`` is the number of APIs in each category, active or not.
    """
    cats = (
        db.query(SharedApiCategory)
        .order_by(SharedApiCategory.sort_order, SharedApiCategory.id)
        .all()
    )
    # One GROUP BY query instead of a separate COUNT per category (N+1).
    counts = dict(
        db.query(SharedApi.category_id, sa_func.count(SharedApi.id))
        .group_by(SharedApi.category_id)
        .all()
    )
    return [
        {
            "id": c.id,
            "name": c.name,
            "icon": c.icon,
            "sort_order": c.sort_order,
            "is_active": c.is_active,
            "api_count": counts.get(c.id, 0),
        }
        for c in cats
    ]
|
||
|
||
@router.post("/categories")
def create_category(
    data: CategoryCreate,
    db: Session = Depends(get_db),
    user: User = Depends(verify_hub_access),
):
    """Create a category and append it after the current highest sort_order.

    Responds 400 if a category with the same name already exists.
    """
    duplicate = db.query(SharedApiCategory).filter(SharedApiCategory.name == data.name).first()
    if duplicate:
        raise HTTPException(status_code=400, detail="分类名称已存在")

    highest = db.query(sa_func.max(SharedApiCategory.sort_order)).scalar() or 0
    category = SharedApiCategory(name=data.name, icon=data.icon, sort_order=highest + 1)
    db.add(category)
    db.commit()
    db.refresh(category)
    return {"id": category.id, "name": category.name, "icon": category.icon}
|
||
|
||
@router.put("/categories/{cat_id}")
def update_category(
    cat_id: int,
    data: CategoryUpdate,
    db: Session = Depends(get_db),
    user: User = Depends(verify_hub_access),
):
    """Partially update a category; fields left as None are not changed.

    Responds 404 if the category does not exist. Responds 400 if the new name
    is already used by another category, the same rule create_category
    enforces. Previously a rename could create a duplicate.
    """
    cat = db.query(SharedApiCategory).filter(SharedApiCategory.id == cat_id).first()
    if not cat:
        raise HTTPException(status_code=404, detail="分类不存在")

    # Validate the rename before modifying any attribute, so a rejected
    # request leaves the session untouched.
    if data.name is not None and data.name != cat.name:
        clash = (
            db.query(SharedApiCategory)
            .filter(SharedApiCategory.name == data.name, SharedApiCategory.id != cat_id)
            .first()
        )
        if clash:
            raise HTTPException(status_code=400, detail="分类名称已存在")

    if data.name is not None:
        cat.name = data.name
    if data.icon is not None:
        cat.icon = data.icon
    if data.sort_order is not None:
        cat.sort_order = data.sort_order
    if data.is_active is not None:
        cat.is_active = data.is_active
    db.commit()
    return {"message": "更新成功"}
|
||
|
||
@router.delete("/categories/{cat_id}")
def delete_category(
    cat_id: int,
    db: Session = Depends(get_db),
    user: User = Depends(verify_hub_access),
):
    """Delete a category; its APIs are kept but become uncategorised.

    Responds 404 if the category does not exist.
    """
    category = db.query(SharedApiCategory).filter(SharedApiCategory.id == cat_id).first()
    if not category:
        raise HTTPException(status_code=404, detail="分类不存在")

    # Detach (rather than delete) every API that pointed at this category.
    members = db.query(SharedApi).filter(SharedApi.category_id == cat_id)
    members.update({SharedApi.category_id: None})
    db.delete(category)
    db.commit()
    return {"message": "删除成功"}
|
||
|
||
|
||
# ========== API CRUD ==========
|
||
|
||
def _api_to_dict(api, db=None):
    """Serialise a SharedApi row into the JSON shape returned by the hub.

    ``db`` is unused and kept only for call compatibility. Datetimes are
    rendered with isoformat(), or None when unset.

    NOTE(review): ``api_key_plain`` returns the decrypted key to every hub
    user, which makes ``api_key_masked`` purely cosmetic. Confirm this
    exposure is intended.
    """

    def iso_or_none(moment):
        return moment.isoformat() if moment else None

    encrypted = api.api_key_encrypted
    payload = {}
    payload["id"] = api.id
    payload["category_id"] = api.category_id
    payload["name"] = api.name
    payload["description"] = api.description
    payload["base_url"] = api.base_url
    payload["doc_url"] = api.doc_url
    payload["auth_type"] = api.auth_type
    payload["api_key_masked"] = mask_key(encrypted)
    payload["api_key_plain"] = decrypt_key(encrypted) if encrypted else ""
    payload["has_api_key"] = bool(encrypted)
    payload["api_key_header"] = api.api_key_header
    payload["health_check_url"] = api.health_check_url
    payload["last_check_time"] = iso_or_none(api.last_check_time)
    payload["last_check_status"] = api.last_check_status
    payload["added_by"] = api.added_by
    payload["tags"] = api.tags
    payload["call_count"] = api.call_count
    payload["is_active"] = api.is_active
    payload["created_at"] = iso_or_none(api.created_at)
    payload["updated_at"] = iso_or_none(api.updated_at)
    return payload
|
||
|
||
@router.get("/list")
def list_apis(
    keyword: Optional[str] = None,
    category_id: Optional[int] = None,
    db: Session = Depends(get_db),
    user: User = Depends(verify_hub_access),
):
    """List active APIs, most-called first, optionally filtered.

    ``keyword`` is matched as a literal substring against name, description
    and tags. ``category_id`` restricts the list to one category.
    """
    query = db.query(SharedApi).filter(SharedApi.is_active == True)  # noqa: E712
    if keyword:
        # Escape LIKE wildcards so "%" or "_" in the keyword are matched
        # literally. Previously "api_key" also matched "apiXkey", and "%"
        # matched everything.
        escaped = keyword.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        pattern = f"%{escaped}%"
        query = query.filter(
            SharedApi.name.like(pattern, escape="\\")
            | SharedApi.description.like(pattern, escape="\\")
            | SharedApi.tags.like(pattern, escape="\\")
        )
    if category_id is not None:
        query = query.filter(SharedApi.category_id == category_id)
    apis = query.order_by(SharedApi.call_count.desc(), SharedApi.id.desc()).all()
    return {"items": [_api_to_dict(a) for a in apis], "total": len(apis)}
|
||
|
||
@router.post("/")
def create_api(
    data: ApiCreate,
    db: Session = Depends(get_db),
    user: User = Depends(verify_hub_access),
):
    """Register a new shared API and return its serialised form.

    The plaintext ``api_key`` is encrypted before storage. The caller is
    recorded in ``added_by``.
    """
    encrypted_key = encrypt_key(data.api_key) if data.api_key else ""
    new_api = SharedApi(
        category_id=data.category_id,
        name=data.name,
        description=data.description,
        base_url=data.base_url,
        doc_url=data.doc_url,
        auth_type=data.auth_type,
        api_key_encrypted=encrypted_key,
        api_key_header=data.api_key_header,
        health_check_url=data.health_check_url,
        tags=data.tags,
        added_by=user.id,
    )
    db.add(new_api)
    db.commit()
    db.refresh(new_api)
    return _api_to_dict(new_api)
|
||
|
||
@router.put("/{api_id}")
def update_api(
    api_id: int,
    data: ApiUpdate,
    db: Session = Depends(get_db),
    user: User = Depends(verify_hub_access),
):
    """Partially update a shared API and return its serialised form.

    Fields left as None are not changed. A non-empty ``api_key`` replaces the
    stored encrypted key, while an empty one keeps it. Responds 404 if the API
    does not exist.
    """
    api = db.query(SharedApi).filter(SharedApi.id == api_id).first()
    if not api:
        raise HTTPException(status_code=404, detail="API不存在")

    # Applied in declaration order. "api_key" is special-cased because it is
    # stored encrypted under a different column name.
    updatable = (
        "category_id", "name", "description", "base_url", "doc_url", "auth_type",
        "api_key", "api_key_header", "health_check_url", "tags", "is_active",
    )
    for field_name in updatable:
        value = getattr(data, field_name)
        if field_name == "api_key":
            if value:
                api.api_key_encrypted = encrypt_key(value)
        elif value is not None:
            setattr(api, field_name, value)

    db.commit()
    db.refresh(api)
    return _api_to_dict(api)
|
||
|
||
@router.delete("/{api_id}")
def delete_api(
    api_id: int,
    db: Session = Depends(get_db),
    user: User = Depends(verify_hub_access),
):
    """Delete a shared API together with all of its log rows.

    Responds 404 if the API does not exist.
    """
    target = db.query(SharedApi).filter(SharedApi.id == api_id).first()
    if not target:
        raise HTTPException(status_code=404, detail="API不存在")

    # Remove dependent logs first, then the API, in a single commit.
    related_logs = db.query(SharedApiLog).filter(SharedApiLog.api_id == api_id)
    related_logs.delete()
    db.delete(target)
    db.commit()
    return {"message": "删除成功"}
|
||
|
||
|
||
# ========== API测试 ==========
|
||
|
||
# Methods the test proxy forwards as-is; anything else is sent as GET.
_PROXY_METHODS = frozenset({"GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"})
# Methods for which the request body from ApiTestRequest is sent.
_BODY_METHODS = frozenset({"POST", "PUT", "PATCH"})


def _inject_auth_headers(api, headers: dict) -> None:
    """Add *api*'s stored credential to *headers* (in place) per its auth_type."""
    if api.auth_type == "none" or not api.api_key_encrypted:
        return
    key = decrypt_key(api.api_key_encrypted)
    if not key:
        return
    if api.auth_type == "bearer":
        headers[api.api_key_header] = f"Bearer {key}"
    elif api.auth_type == "api_key":
        headers[api.api_key_header] = key
    elif api.auth_type == "basic":
        import base64

        # The stored key is used verbatim as the "user:password" credential.
        headers["Authorization"] = f"Basic {base64.b64encode(key.encode()).decode()}"


@router.post("/{api_id}/test")
async def test_api(
    api_id: int,
    data: ApiTestRequest,
    db: Session = Depends(get_db),
    user: User = Depends(verify_hub_access),
):
    """Proxy a test request to a shared API from the backend.

    The URL is ``base_url`` plus ``data.path``. The stored credential is
    injected into the headers. A log row is written in every case.

    * On success, ``call_count`` is incremented and the response is returned
      as ``{status_code, response_time_ms, headers, body}``, with ``body``
      capped at 5000 characters.
    * When no response is obtained, ``status_code`` is 0 and ``body`` holds
      the error text.

    Fixes over the previous version:

    * PATCH/HEAD/OPTIONS were silently sent as GET.
    * Database errors were caught by the same ``except`` as network errors.
      They were then logged on the failed session as if the request had
      failed. Only the HTTP call is guarded now; DB errors propagate.
    * Elapsed time uses a monotonic clock.

    NOTE(review): any hub user can make the server request an arbitrary URL
    (SSRF). Consider an allow-list if the server can reach internal hosts.
    This is an ``async def`` running synchronous Session calls, which block
    the event loop.
    """
    api = db.query(SharedApi).filter(SharedApi.id == api_id).first()
    if not api:
        raise HTTPException(status_code=404, detail="API不存在")

    url = api.base_url.rstrip("/")
    if data.path:
        url = url + "/" + data.path.lstrip("/")

    headers = dict(data.headers) if data.headers else {}
    _inject_auth_headers(api, headers)

    method = data.method.upper()
    if method not in _PROXY_METHODS:
        method = "GET"
    content = (data.body or None) if method in _BODY_METHODS else None

    import httpx

    start = time.monotonic()
    try:
        async with httpx.AsyncClient(timeout=15) as client:
            resp = await client.request(method, url, headers=headers, content=content)
    except Exception as exc:
        # Unreachable host, timeout, invalid URL or bad header value: report
        # it to the caller instead of failing the endpoint.
        elapsed = int((time.monotonic() - start) * 1000)
        db.add(SharedApiLog(
            api_id=api_id, user_id=user.id, action="test",
            request_url=url, response_status=0, response_time_ms=elapsed,
        ))
        db.commit()
        return {"status_code": 0, "response_time_ms": elapsed, "headers": {}, "body": str(exc)}
    elapsed = int((time.monotonic() - start) * 1000)

    db.add(SharedApiLog(
        api_id=api_id, user_id=user.id, action="test",
        request_url=url, response_status=resp.status_code, response_time_ms=elapsed,
    ))
    api.call_count = (api.call_count or 0) + 1
    db.commit()

    return {
        "status_code": resp.status_code,
        "response_time_ms": elapsed,
        "headers": dict(resp.headers),
        "body": resp.text[:5000],  # cap the echoed body size
    }
|
||
|
||
|
||
@router.post("/{api_id}/health-check")
async def health_check(
    api_id: int,
    db: Session = Depends(get_db),
    user: User = Depends(verify_hub_access),
):
    """GET the API's health_check_url (or base_url) and record the outcome.

    The status is "ok" for an HTTP response below 400, otherwise "error". It
    is stored on the API with the check time, and a log row is written. The
    log's ``response_status`` is now the real HTTP status whenever a response
    arrived. Previously 4xx/5xx responses were logged as 0, indistinguishable
    from an unreachable host. 0 now means "no response".

    NOTE(review): ``async def`` with synchronous Session calls blocks the
    event loop. last_check_time stays a naive UTC datetime, presumably to
    match the column type; confirm.
    """
    api = db.query(SharedApi).filter(SharedApi.id == api_id).first()
    if not api:
        raise HTTPException(status_code=404, detail="API不存在")

    check_url = api.health_check_url or api.base_url
    import httpx

    start = time.monotonic()
    try:
        async with httpx.AsyncClient(timeout=10) as client:
            resp = await client.get(check_url)
    except Exception:
        # Unreachable, timed out or invalid URL: counts as unhealthy.
        status_code = 0
    else:
        status_code = resp.status_code
    elapsed = int((time.monotonic() - start) * 1000)
    status = "ok" if 0 < status_code < 400 else "error"

    api.last_check_time = datetime.utcnow()
    api.last_check_status = status
    db.add(SharedApiLog(
        api_id=api_id, user_id=user.id, action="health_check",
        request_url=check_url, response_status=status_code,
        response_time_ms=elapsed,
    ))
    db.commit()
    return {"status": status, "response_time_ms": elapsed}
|
||
|
||
|
||
# ========== 日志与统计 ==========
|
||
|
||
# Upper bound on the number of log rows get_api_logs returns per request.
_MAX_LOG_LIMIT = 200


@router.get("/{api_id}/logs")
def get_api_logs(
    api_id: int,
    limit: int = 20,
    db: Session = Depends(get_db),
    user: User = Depends(verify_hub_access),
):
    """Return the newest log entries for one API, newest first.

    ``limit`` is clamped to [0, _MAX_LOG_LIMIT]. Previously it reached SQL
    unchecked: a huge value could dump the whole table, and on SQLite a
    negative LIMIT means "no limit". An unknown ``api_id`` yields an empty
    list.
    """
    limit = max(0, min(limit, _MAX_LOG_LIMIT))
    logs = (
        db.query(SharedApiLog)
        .filter(SharedApiLog.api_id == api_id)
        .order_by(SharedApiLog.id.desc())
        .limit(limit)
        .all()
    )
    return [
        {
            "id": entry.id,
            "action": entry.action,
            "request_url": entry.request_url,
            "response_status": entry.response_status,
            "response_time_ms": entry.response_time_ms,
            "user_id": entry.user_id,
            "created_at": entry.created_at.isoformat() if entry.created_at else None,
        }
        for entry in logs
    ]
|
||
|
||
@router.get("/stats")
def get_stats(
    db: Session = Depends(get_db),
    user: User = Depends(verify_hub_access),
):
    """Return hub-wide counters; each one falls back to 0 when the query yields NULL.

    * total_apis: active APIs only.
    * total_calls: sum of call_count over all APIs.
    * total_categories: every category.
    * healthy_count: APIs whose last check was "ok", active or not.
    """

    def scalar_or_zero(query):
        return query.scalar() or 0

    return {
        "total_apis": scalar_or_zero(
            db.query(sa_func.count(SharedApi.id)).filter(SharedApi.is_active == True)  # noqa: E712
        ),
        "total_calls": scalar_or_zero(db.query(sa_func.sum(SharedApi.call_count))),
        "total_categories": scalar_or_zero(db.query(sa_func.count(SharedApiCategory.id))),
        "healthy_count": scalar_or_zero(
            db.query(sa_func.count(SharedApi.id)).filter(SharedApi.last_check_status == "ok")
        ),
    }
|