mirror of
https://git.oceanpay.cc/danial/kami_apple_exchage.git
synced 2025-12-18 21:23:49 +00:00
- 新增 CODEBUDDY.md、GEMINI.md、GEMINI_CN.md 等项目文档 - 更新 Dockerfile 和其他配置文件 - 优化部分代码结构,如 orders.py、tasks.py 等 - 新增 .dockerignore 文件
402 lines
12 KiB
Python
402 lines
12 KiB
Python
"""
|
||
FastAPI中间件模块
|
||
包含请求日志、安全头、CORS、OpenTelemetry等中间件配置
|
||
"""
|
||
|
||
import time
|
||
import traceback
|
||
import uuid
|
||
from typing import Callable, Awaitable, Optional
|
||
|
||
from fastapi import FastAPI, Request, Response
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from starlette.middleware.base import BaseHTTPMiddleware
|
||
from opentelemetry import trace, metrics
|
||
|
||
from app.core.config import get_settings
|
||
from app.core.log import get_logger, LogContext, request_id_var
|
||
|
||
# Application settings, read once at import time (setup_cors_middleware reads
# its CORS_* values from this object).
settings = get_settings()
# Module logger used by the middleware and helper functions below.
logger = get_logger(__name__)
|
||
|
||
|
||
class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Request logging middleware integrating OpenTelemetry and structured logging.

    For every request it:

    * derives a request ID: the active OpenTelemetry trace ID (formatted as 32
      lowercase hex digits) when the current span is recording, otherwise a
      random UUID4 string;
    * stores the ID on ``request.state.request_id`` and binds it into
      ``LogContext`` for the rest of the request;
    * on success, sets the ``X-Request-ID`` and ``X-Process-Time`` (seconds,
      three decimals) response headers and logs "请求完成" with the status code;
    * if the downstream app raises, logs "请求失败" with the traceback and
      re-raises the exception unchanged.
    """

    @staticmethod
    def _resolve_request_id() -> str:
        """Return the current trace ID as 32 hex chars, or a new UUID4 if no span is recording."""
        current_span = trace.get_current_span()
        if current_span.is_recording():
            # Reuse the OpenTelemetry trace ID so logs and traces correlate.
            return format(current_span.get_span_context().trace_id, "032x")
        return str(uuid.uuid4())

    async def dispatch(
        self, request: Request, call_next: Callable[[Request], Awaitable[Response]]
    ) -> Response:
        """Tag the request with an ID, time it, and log its outcome.

        Args:
            request: The incoming request.
            call_next: Invokes the rest of the middleware stack / endpoint.

        Returns:
            The downstream response, with ``X-Request-ID`` and
            ``X-Process-Time`` headers added.

        Raises:
            Exception: Any exception raised downstream is logged and re-raised.
        """
        request_id = self._resolve_request_id()
        request.state.request_id = request_id
        # perf_counter() is monotonic: unlike time.time(), the measured duration
        # cannot go negative or jump when the system clock is adjusted.
        start_time = time.perf_counter()

        with LogContext(request_id=request_id):
            # Only the downstream call is guarded, so a failure in this class's
            # own header/log code is not misreported as a failed request.
            try:
                response = await call_next(request)
            except Exception:
                logger.error(
                    "请求失败",
                    error=traceback.format_exc(),
                    process_time=time.perf_counter() - start_time,
                    exc_info=True,
                )
                raise

            process_time = time.perf_counter() - start_time
            response.headers["X-Request-ID"] = request_id
            response.headers["X-Process-Time"] = f"{process_time:.3f}"

            logger.info(
                "请求完成",
                status_code=response.status_code,
                process_time=process_time,
            )
            return response
|
||
|
||
|
||
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Middleware that adds a fixed set of security headers to every response.

    The downstream app builds the response first; the headers listed in
    ``_SECURITY_HEADERS`` are then assigned on it, in table order.
    """

    # (header name, value) pairs, assigned in this order.
    _SECURITY_HEADERS = (
        ("X-Content-Type-Options", "nosniff"),
        ("X-Frame-Options", "DENY"),
        ("X-XSS-Protection", "1; mode=block"),
        ("Strict-Transport-Security", "max-age=31536000; includeSubDomains"),
        ("Referrer-Policy", "strict-origin-when-cross-origin"),
    )

    async def dispatch(
        self, request: Request, call_next: Callable[[Request], Awaitable[Response]]
    ) -> Response:
        """Let the app handle *request*, then stamp the security headers on the result."""
        response = await call_next(request)
        outgoing = response.headers
        for header_name, header_value in self._SECURITY_HEADERS:
            outgoing[header_name] = header_value
        return response
|
||
|
||
|
||
def setup_cors_middleware(app: FastAPI):
    """Register CORSMiddleware on *app* using the module-level settings.

    Origins, methods and headers come from ``settings.CORS_ORIGINS``,
    ``settings.CORS_METHODS`` and ``settings.CORS_HEADERS``. Credentials are
    always allowed.
    """
    # NOTE(review): the CORS spec forbids a wildcard "*" origin on credentialed
    # responses. Confirm CORS_ORIGINS is an explicit allow-list in deployment.
    cors_options = {
        "allow_origins": settings.CORS_ORIGINS,
        "allow_credentials": True,
        "allow_methods": settings.CORS_METHODS,
        "allow_headers": settings.CORS_HEADERS,
    }
    app.add_middleware(CORSMiddleware, **cors_options)
|
||
|
||
|
||
def setup_custom_middleware(app: FastAPI):
    """Register the project's own middleware on *app*.

    Registration order: SecurityHeadersMiddleware first, then
    RequestLoggingMiddleware (which integrates OpenTelemetry and structured
    logging). In Starlette, later additions typically wrap earlier ones, which
    would make request logging the outer layer. Confirm this if you rely on it.
    """
    for middleware_cls in (SecurityHeadersMiddleware, RequestLoggingMiddleware):
        app.add_middleware(middleware_cls)
|
||
|
||
|
||
class MetricsMiddleware(BaseHTTPMiddleware):
    """Middleware dedicated to HTTP metrics collection.

    Uses the OpenTelemetry meter for this module to record:

    * ``http_requests_total`` (counter) and ``http_request_duration_seconds``
      (histogram, seconds), labelled with ``method``, ``status_code`` and
      ``service``;
    * ``http_active_requests`` (up/down counter) for in-flight requests;
    * ``http_errors_total`` (counter) only when the downstream app raises,
      with an extra ``error_type`` label holding the exception class name.
      Error responses that the app returns normally are not counted here.

    When ``OTEL_ENABLED`` or ``OTEL_METRICS_ENABLED`` is false, requests pass
    straight through and nothing is recorded.
    """

    def __init__(self, app):
        super().__init__(app)
        self.settings = get_settings()
        self.service_name = self.settings.OTEL_SERVICE_NAME

        # Meter scoped to this module
        self.meter = metrics.get_meter(__name__)

        # Create the instruments (skipped when metrics are disabled)
        self._init_metrics()

    def _init_metrics(self):
        """Create the HTTP instruments unless ``OTEL_METRICS_ENABLED`` is false.

        When disabled, the instrument attributes are never set. The
        ``hasattr`` checks in ``dispatch`` and ``_record_metrics`` rely on that.
        """
        if not self.settings.OTEL_METRICS_ENABLED:
            return

        # Total number of HTTP requests
        self.request_total = self.meter.create_counter(
            name="http_requests_total",
            description="Total HTTP requests",
            unit="1",
        )

        # HTTP request duration
        self.request_duration = self.meter.create_histogram(
            name="http_request_duration_seconds",
            description="HTTP request duration in seconds",
            unit="s",
        )

        # Requests currently in flight
        self.active_requests = self.meter.create_up_down_counter(
            name="http_active_requests",
            description="Currently active HTTP requests",
            unit="1",
        )

        # Requests whose handling raised an exception
        self.error_total = self.meter.create_counter(
            name="http_errors_total",
            description="Total HTTP errors",
            unit="1",
        )

    async def dispatch(
        self, request: Request, call_next: Callable[[Request], Awaitable[Response]]
    ) -> Response:
        """Run the request and record count, duration, in-flight and error metrics.

        Args:
            request: The incoming request.
            call_next: Invokes the rest of the middleware stack / endpoint.

        Returns:
            The downstream response, unchanged.

        Raises:
            Exception: Any downstream exception is recorded as a 500 and re-raised.
        """
        if not self.settings.OTEL_ENABLED or not self.settings.OTEL_METRICS_ENABLED:
            return await call_next(request)

        # Count this request as in flight
        if hasattr(self, "active_requests"):
            self.active_requests.add(1, {"service": self.service_name})

        # Monotonic clock: time.time() can jump on clock adjustments and yield
        # negative or skewed durations in the histogram.
        start_time = time.perf_counter()
        try:
            response = await call_next(request)
        except Exception as exc:
            # No response exists when the app raises, so record it as a 500.
            self._record_metrics(
                request=request,
                status_code=500,
                duration=time.perf_counter() - start_time,
                error_type=type(exc).__name__,
            )
            raise
        else:
            # Recorded outside the try, so an exception raised while recording
            # cannot also be counted as a 500 error.
            self._record_metrics(
                request=request,
                status_code=response.status_code,
                duration=time.perf_counter() - start_time,
                error_type=None,
            )
            return response
        finally:
            # Always balance the +1 above, including when a BaseException
            # (e.g. cancellation) propagates.
            if hasattr(self, "active_requests"):
                self.active_requests.add(-1, {"service": self.service_name})

    def _record_metrics(
        self,
        request: Request,
        status_code: int,
        duration: float,
        error_type: Optional[str] = None,
    ) -> None:
        """Record the request counter, the duration histogram and, if *error_type*, the error counter.

        Args:
            request: The request being measured (only ``method`` is used).
            status_code: Response status, or 500 for an unhandled exception.
            duration: Elapsed time in seconds.
            error_type: Exception class name, or ``None`` for a normal response.
        """
        # Base labels
        labels = {
            "method": request.method,
            "status_code": str(status_code),
            "service": self.service_name,
        }

        if hasattr(self, "request_total"):
            self.request_total.add(1, labels)

        if hasattr(self, "request_duration"):
            self.request_duration.record(duration, labels)

        if error_type and hasattr(self, "error_total"):
            error_labels = {**labels, "error_type": error_type}
            self.error_total.add(1, error_labels)
|
||
|
||
|
||
class BusinessMetricsCollector:
    """Collector for business-level OpenTelemetry metrics.

    Provides:

    * counters for gift card, order, crawler and Redis operations and for
      business errors;
    * an up/down counter for active database connections.

    When ``OTEL_METRICS_ENABLED`` is false, no instruments are created and
    every ``record_*`` / ``update_*`` method silently does nothing.

    Every data point carries a ``service`` attribute set to
    ``OTEL_SERVICE_NAME``. Extra keyword ``labels`` are merged in last, so they
    can override the built-in attributes.
    """

    def __init__(self):
        self.settings = get_settings()
        self.meter = metrics.get_meter(__name__)
        self.service_name = self.settings.OTEL_SERVICE_NAME

        self._init_business_metrics()

    def _init_business_metrics(self):
        """Create the business instruments unless ``OTEL_METRICS_ENABLED`` is false."""
        if not self.settings.OTEL_METRICS_ENABLED:
            # Instrument attributes stay unset; _emit() checks for them.
            return

        meter = self.meter

        # Gift card operations
        self.gift_card_operations = meter.create_counter(
            name="gift_card_operations_total",
            description="Total gift card operations",
            unit="1",
        )
        # Order operations
        self.order_operations = meter.create_counter(
            name="order_operations_total",
            description="Total order operations",
            unit="1",
        )
        # Crawler operations
        self.crawler_operations = meter.create_counter(
            name="crawler_operations_total",
            description="Total crawler operations",
            unit="1",
        )
        # Active database connections
        self.db_connections = meter.create_up_down_counter(
            name="database_connections_active",
            description="Active database connections",
            unit="1",
        )
        # Redis operations
        self.redis_operations = meter.create_counter(
            name="redis_operations_total",
            description="Total Redis operations",
            unit="1",
        )
        # Business logic errors
        self.business_errors = meter.create_counter(
            name="business_errors_total",
            description="Total business logic errors",
            unit="1",
        )

    def _attributes(self, fixed, extra):
        """Merge *fixed* attributes, then ``service``, then caller-supplied *extra* labels."""
        merged = dict(fixed)
        merged["service"] = self.service_name
        merged.update(extra)
        return merged

    def _emit(self, instrument_name, amount, attributes):
        """Add *amount* to the instrument stored under *instrument_name*, if it was created."""
        if hasattr(self, instrument_name):
            getattr(self, instrument_name).add(amount, attributes)

    def record_gift_card_operation(
        self, operation: str, status: str = "success", **labels
    ):
        """Count one gift card *operation* with the given *status*."""
        attrs = self._attributes({"operation": operation, "status": status}, labels)
        self._emit("gift_card_operations", 1, attrs)

    def record_order_operation(self, operation: str, status: str = "success", **labels):
        """Count one order *operation* with the given *status*."""
        attrs = self._attributes({"operation": operation, "status": status}, labels)
        self._emit("order_operations", 1, attrs)

    def record_crawler_operation(
        self, crawler_type: str, status: str = "success", **labels
    ):
        """Count one crawler run of *crawler_type* with the given *status*."""
        attrs = self._attributes({"crawler_type": crawler_type, "status": status}, labels)
        self._emit("crawler_operations", 1, attrs)

    def update_db_connections(self, delta: int):
        """Adjust the active database connection gauge by *delta* (may be negative)."""
        self._emit("db_connections", delta, self._attributes({}, {}))

    def record_redis_operation(self, operation: str, status: str = "success", **labels):
        """Count one Redis *operation* with the given *status*."""
        attrs = self._attributes({"operation": operation, "status": status}, labels)
        self._emit("redis_operations", 1, attrs)

    def record_business_error(
        self, error_type: str, operation: str = "unknown", **labels
    ):
        """Count one business error of *error_type* raised during *operation*."""
        attrs = self._attributes({"error_type": error_type, "operation": operation}, labels)
        self._emit("business_errors", 1, attrs)
|
||
|
||
|
||
# Process-wide collector, created lazily by get_business_metrics_collector().
_business_metrics_collector: Optional[BusinessMetricsCollector] = None


def get_business_metrics_collector() -> BusinessMetricsCollector:
    """Return the shared BusinessMetricsCollector, creating it on the first call.

    The instance is cached in the module-level ``_business_metrics_collector``
    and reused on later calls. No locking is performed here.
    """
    global _business_metrics_collector
    collector = _business_metrics_collector
    if collector is None:
        collector = BusinessMetricsCollector()
        _business_metrics_collector = collector
    return collector
|
||
|
||
|
||
def add_api_logging_middleware(app) -> None:
    """Backward-compatible no-op kept for older callers.

    API request logging is now done by RequestLoggingMiddleware, which
    setup_custom_middleware registers. *app* is accepted and ignored.
    """
    return None
|
||
|
||
|
||
def add_metrics_middleware(app):
    """Attach MetricsMiddleware to *app* when OpenTelemetry metrics are enabled.

    Both ``OTEL_ENABLED`` and ``OTEL_METRICS_ENABLED`` must be truthy. If
    either is not, an info message is logged and *app* is left unchanged.
    """
    cfg = get_settings()
    metrics_enabled = cfg.OTEL_ENABLED and cfg.OTEL_METRICS_ENABLED

    if metrics_enabled:
        app.add_middleware(MetricsMiddleware)
        logger.info("✅ Metrics 中间件已添加")
    else:
        logger.info("⚠️ OpenTelemetry Metrics 未启用,跳过中间件设置")
|