Files
kami_apple_exchage/backend/app/core/middleware.py
danial 5c486e34d3 docs(项目): 添加项目文档并进行代码调整
- 新增 CODEBUDDY.md、GEMINI.md、GEMINI_CN.md 等项目文档
- 更新 Dockerfile 和其他配置文件
- 优化部分代码结构,如 orders.py、tasks.py 等
- 新增 .dockerignore 文件
2025-09-12 19:38:24 +08:00

402 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
FastAPI中间件模块
包含请求日志、安全头、CORS、OpenTelemetry等中间件配置
"""
import time
import traceback
import uuid
from typing import Callable, Awaitable, Optional
from fastapi import FastAPI, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.base import BaseHTTPMiddleware
from opentelemetry import trace, metrics
from app.core.config import get_settings
from app.core.log import get_logger, LogContext, request_id_var
# Application settings, obtained once when this module is imported.
# setup_cors_middleware() reads its CORS options from this object.
settings = get_settings()
# Module logger; RequestLoggingMiddleware passes structured keyword fields
# (e.g. status_code=..., process_time=...) to it.
logger = get_logger(__name__)
class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Request logging middleware integrating OpenTelemetry and structured logs.

    For every request it:

    * derives a request ID: the current OpenTelemetry trace ID (32 lowercase
      hex chars) when a recording span is active, otherwise a random UUID4;
    * stores it on ``request.state.request_id`` and binds it to the logging
      context via ``LogContext`` while the request is processed;
    * sets the ``X-Request-ID`` and ``X-Process-Time`` (seconds, 3 decimals)
      response headers;
    * logs completion, or logs the failure and re-raises the exception.
    """

    async def dispatch(
        self, request: Request, call_next: Callable[[Request], Awaitable[Response]]
    ) -> Response:
        """Process one request; see the class docstring for what is recorded.

        Raises:
            Whatever ``call_next`` raises; it is logged and then re-raised.
        """
        request_id = self._resolve_request_id()
        request.state.request_id = request_id

        # perf_counter() is monotonic: unlike time.time() it is not affected by
        # wall-clock (NTP/manual) adjustments, so durations are never negative
        # or inflated by a clock jump.
        start_time = time.perf_counter()

        with LogContext(request_id=request_id):
            try:
                response = await call_next(request)
            except Exception:
                process_time = time.perf_counter() - start_time
                logger.error(
                    "请求失败",
                    error=traceback.format_exc(),
                    process_time=process_time,
                    exc_info=True,
                )
                raise

            process_time = time.perf_counter() - start_time
            response.headers["X-Request-ID"] = request_id
            response.headers["X-Process-Time"] = f"{process_time:.3f}"
            logger.info(
                "请求完成",
                status_code=response.status_code,
                process_time=process_time,
            )
            return response

    @staticmethod
    def _resolve_request_id() -> str:
        """Return the active trace ID as 32 hex chars, or a fresh UUID4 string."""
        current_span = trace.get_current_span()
        if current_span.is_recording():
            return format(current_span.get_span_context().trace_id, "032x")
        return str(uuid.uuid4())
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Add a baseline set of security headers to every response.

    Values are applied with ``setdefault``: a header the endpoint already set
    (e.g. ``X-Frame-Options: SAMEORIGIN`` on a page meant to be embedded) is
    kept instead of being silently overwritten by this middleware.
    """

    # (header name, value) pairs applied when the response lacks the header.
    # A tuple keeps this shared class-level default immutable.
    _DEFAULT_HEADERS = (
        ("X-Content-Type-Options", "nosniff"),
        ("X-Frame-Options", "DENY"),
        # The legacy browser XSS auditor has been removed from current
        # browsers, and in older ones "1; mode=block" could itself be abused;
        # OWASP and MDN recommend explicitly disabling it with "0" and relying
        # on a Content-Security-Policy instead.
        ("X-XSS-Protection", "0"),
        ("Strict-Transport-Security", "max-age=31536000; includeSubDomains"),
        ("Referrer-Policy", "strict-origin-when-cross-origin"),
    )

    async def dispatch(
        self, request: Request, call_next: Callable[[Request], Awaitable[Response]]
    ) -> Response:
        """Forward the request, then fill in any missing security headers."""
        response = await call_next(request)
        for name, value in self._DEFAULT_HEADERS:
            response.headers.setdefault(name, value)
        return response
def setup_cors_middleware(app: FastAPI):
    """Register Starlette's CORSMiddleware on *app*.

    Allowed origins, methods and headers are taken from the module-level
    ``settings`` object (``CORS_ORIGINS``, ``CORS_METHODS``, ``CORS_HEADERS``);
    credentials are always allowed.

    NOTE(review): with ``allow_credentials=True`` make sure ``CORS_ORIGINS`` is
    an explicit allow-list rather than ``["*"]`` - confirm against the config.
    """
    cors_options = dict(
        allow_origins=settings.CORS_ORIGINS,
        allow_credentials=True,
        allow_methods=settings.CORS_METHODS,
        allow_headers=settings.CORS_HEADERS,
    )
    app.add_middleware(CORSMiddleware, **cors_options)
def setup_custom_middleware(app: FastAPI):
    """Register this module's custom middlewares on *app*.

    SecurityHeadersMiddleware is added first and RequestLoggingMiddleware
    second (RequestLoggingMiddleware carries the OpenTelemetry / structured
    logging integration). The registration order is significant to Starlette,
    which nests later-added middleware outside earlier ones.
    """
    for middleware_cls in (SecurityHeadersMiddleware, RequestLoggingMiddleware):
        app.add_middleware(middleware_cls)
class MetricsMiddleware(BaseHTTPMiddleware):
    """Middleware dedicated to collecting HTTP request metrics (OpenTelemetry).

    Instruments are created in ``_init_metrics`` only when
    ``OTEL_METRICS_ENABLED`` is true:

    * ``http_requests_total`` (counter): one per finished request;
    * ``http_request_duration_seconds`` (histogram): request duration;
    * ``http_active_requests`` (up/down counter): requests in flight;
    * ``http_errors_total`` (counter): requests whose downstream call raised.

    Measurements carry ``method``, ``status_code`` and ``service`` attributes;
    error measurements additionally carry ``error_type`` (the exception class
    name). A raised exception is recorded with ``status_code="500"``.
    """

    def __init__(self, app):
        super().__init__(app)
        self.settings = get_settings()
        self.service_name = self.settings.OTEL_SERVICE_NAME
        # Meter from which every instrument below is created.
        self.meter = metrics.get_meter(__name__)
        self._init_metrics()

    def _init_metrics(self):
        """Create the HTTP instruments, or nothing when metrics are disabled.

        When disabled the instrument attributes are simply never set; the
        ``hasattr`` checks in ``dispatch`` and ``_record_metrics`` rely on that.
        """
        if not self.settings.OTEL_METRICS_ENABLED:
            return
        self.request_total = self.meter.create_counter(
            name="http_requests_total",
            description="Total HTTP requests",
            unit="1",
        )
        self.request_duration = self.meter.create_histogram(
            name="http_request_duration_seconds",
            description="HTTP request duration in seconds",
            unit="s",
        )
        self.active_requests = self.meter.create_up_down_counter(
            name="http_active_requests",
            description="Currently active HTTP requests",
            unit="1",
        )
        self.error_total = self.meter.create_counter(
            name="http_errors_total",
            description="Total HTTP errors",
            unit="1",
        )

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Forward the request and record metrics about it.

        Raises:
            Whatever ``call_next`` raises; it is recorded and then re-raised.
        """
        if not self.settings.OTEL_ENABLED or not self.settings.OTEL_METRICS_ENABLED:
            return await call_next(request)

        if hasattr(self, "active_requests"):
            self.active_requests.add(1, {"service": self.service_name})
        # perf_counter() is monotonic, so recorded durations cannot be skewed
        # (or go negative) when the wall clock is adjusted, unlike time.time().
        start_time = time.perf_counter()
        try:
            response = await call_next(request)
        except Exception as e:
            self._record_metrics(
                request=request,
                status_code=500,
                duration=time.perf_counter() - start_time,
                error_type=type(e).__name__,
            )
            raise
        else:
            # Recorded in ``else`` rather than the ``try`` body, so an error
            # raised while recording is not mistaken for a request failure and
            # counted a second time as an error.
            self._record_metrics(
                request=request,
                status_code=response.status_code,
                duration=time.perf_counter() - start_time,
                error_type=None,
            )
            return response
        finally:
            # Always balance the +1 above, whether the request succeeded or not.
            if hasattr(self, "active_requests"):
                self.active_requests.add(-1, {"service": self.service_name})

    def _record_metrics(
        self,
        request: Request,
        status_code: int,
        duration: float,
        error_type: Optional[str] = None,
    ):
        """Record request count and duration, plus an error when *error_type* is set.

        Args:
            request: the incoming request (only ``method`` is used as a label).
            status_code: response status, or 500 for a raised exception.
            duration: elapsed time in seconds.
            error_type: exception class name, or None on success.
        """
        labels = {
            "method": request.method,
            "status_code": str(status_code),
            "service": self.service_name,
        }
        if hasattr(self, "request_total"):
            self.request_total.add(1, labels)
        if hasattr(self, "request_duration"):
            self.request_duration.record(duration, labels)
        if error_type and hasattr(self, "error_total"):
            error_labels = {**labels, "error_type": error_type}
            self.error_total.add(1, error_labels)
class BusinessMetricsCollector:
    """Collect business-level OpenTelemetry metrics.

    Provides counters for gift-card, order, crawler and Redis operations and
    for business errors, plus an up/down counter of active database
    connections.

    Instruments exist only when ``OTEL_METRICS_ENABLED`` was true at
    construction time; otherwise every ``record_*`` / ``update_*`` call is a
    silent no-op. Extra keyword ``**labels`` are merged last into the
    measurement attributes, so they may override built-in keys such as
    ``service``.
    """

    def __init__(self):
        self.settings = get_settings()
        self.meter = metrics.get_meter(__name__)
        self.service_name = self.settings.OTEL_SERVICE_NAME
        self._init_business_metrics()

    def _init_business_metrics(self):
        """Create all business instruments unless metrics are disabled."""
        if not self.settings.OTEL_METRICS_ENABLED:
            return
        meter = self.meter
        self.gift_card_operations = meter.create_counter(
            name="gift_card_operations_total",
            description="Total gift card operations",
            unit="1",
        )
        self.order_operations = meter.create_counter(
            name="order_operations_total",
            description="Total order operations",
            unit="1",
        )
        self.crawler_operations = meter.create_counter(
            name="crawler_operations_total",
            description="Total crawler operations",
            unit="1",
        )
        self.db_connections = meter.create_up_down_counter(
            name="database_connections_active",
            description="Active database connections",
            unit="1",
        )
        self.redis_operations = meter.create_counter(
            name="redis_operations_total",
            description="Total Redis operations",
            unit="1",
        )
        self.business_errors = meter.create_counter(
            name="business_errors_total",
            description="Total business logic errors",
            unit="1",
        )

    def _labels(self, base: dict, extra: dict) -> dict:
        """Build attributes in order: *base*, then ``service``, then *extra*."""
        return {**base, "service": self.service_name, **extra}

    def record_gift_card_operation(
        self, operation: str, status: str = "success", **labels
    ):
        """Count one gift-card operation."""
        if not hasattr(self, "gift_card_operations"):
            return
        self.gift_card_operations.add(
            1, self._labels({"operation": operation, "status": status}, labels)
        )

    def record_order_operation(self, operation: str, status: str = "success", **labels):
        """Count one order operation."""
        if not hasattr(self, "order_operations"):
            return
        self.order_operations.add(
            1, self._labels({"operation": operation, "status": status}, labels)
        )

    def record_crawler_operation(
        self, crawler_type: str, status: str = "success", **labels
    ):
        """Count one crawler run for *crawler_type*."""
        if not hasattr(self, "crawler_operations"):
            return
        self.crawler_operations.add(
            1, self._labels({"crawler_type": crawler_type, "status": status}, labels)
        )

    def update_db_connections(self, delta: int):
        """Adjust the active database connection gauge by *delta* (may be negative)."""
        if not hasattr(self, "db_connections"):
            return
        self.db_connections.add(delta, {"service": self.service_name})

    def record_redis_operation(self, operation: str, status: str = "success", **labels):
        """Count one Redis operation."""
        if not hasattr(self, "redis_operations"):
            return
        self.redis_operations.add(
            1, self._labels({"operation": operation, "status": status}, labels)
        )

    def record_business_error(
        self, error_type: str, operation: str = "unknown", **labels
    ):
        """Count one business-logic error of *error_type* raised during *operation*."""
        if not hasattr(self, "business_errors"):
            return
        self.business_errors.add(
            1, self._labels({"error_type": error_type, "operation": operation}, labels)
        )
# Module-wide collector instance, created lazily by get_business_metrics_collector().
_business_metrics_collector: Optional[BusinessMetricsCollector] = None


def get_business_metrics_collector() -> BusinessMetricsCollector:
    """Return the shared BusinessMetricsCollector, creating it on first call.

    NOTE(review): creation is not guarded by a lock, so simultaneous first
    calls from several threads could each build an instance (the last
    assignment wins). Confirm this is acceptable for the deployment model.
    """
    global _business_metrics_collector
    if _business_metrics_collector is not None:
        return _business_metrics_collector
    _business_metrics_collector = BusinessMetricsCollector()
    return _business_metrics_collector
def add_api_logging_middleware(app) -> None:
    """Backward-compatibility shim that intentionally does nothing.

    API request logging is handled by RequestLoggingMiddleware, which is
    registered via setup_custom_middleware(); *app* is accepted only so that
    existing callers keep working.
    """
    return None
def add_metrics_middleware(app):
    """Attach MetricsMiddleware to *app* when OpenTelemetry metrics are enabled.

    Both ``OTEL_ENABLED`` and ``OTEL_METRICS_ENABLED`` must be truthy; otherwise
    the middleware is skipped and an informational message is logged.
    """
    cfg = get_settings()
    metrics_on = cfg.OTEL_ENABLED and cfg.OTEL_METRICS_ENABLED
    if metrics_on:
        app.add_middleware(MetricsMiddleware)
        logger.info("✅ Metrics 中间件已添加")
    else:
        logger.info("⚠️ OpenTelemetry Metrics 未启用,跳过中间件设置")