build: update Python version and add project configuration

- Downgrade Python from 3.13.3 to 3.12.9
- Add a .gitignore file to ignore logs, compiled caches, and similar files
- Remove empty __init__.py files
- Update the configuration options in app/core/config.py
danial
2025-06-02 18:18:07 +08:00
parent 5c869f86c6
commit 16b3e946e7
28 changed files with 911 additions and 469 deletions

.gitignore (new file)

@@ -0,0 +1,6 @@
/logs/*
# python忽略 pyc
/app/__pycache__/*
/**/**/*.pyc
/.idea/
/.vscode/

View File

@@ -1 +1 @@
python 3.13.3
python 3.12.9

Dockerfile

@@ -1,4 +1,4 @@
FROM python:3.13-slim
FROM python:3.12-slim
WORKDIR /app

View File

@@ -1 +0,0 @@

View File

@@ -1 +0,0 @@

View File

@@ -1,22 +1,35 @@
import traceback
from fastapi import APIRouter, HTTPException, Body, Query, Depends, BackgroundTasks
from fastapi import APIRouter, HTTPException
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from app.models.spider import HeePayQueryRequest, ResponseData, SpiderTaskRequest, SpiderTaskResponse, SpiderInfoResponse, SpiderTaskStatusResponse
from app.models.spider import (
HeePayQueryRequest,
ResponseData,
SpiderTaskRequest,
SpiderTaskResponse,
SpiderInfoResponse,
SpiderTaskStatusResponse,
)
from typing import List, Dict, Any, Optional
import logging
from app.spiders.heepay import HeePaySpider
from app.spiders.heepay import (
HeePaySpider,
CardQueryException,
CardNotFoundError,
CaptchaError,
NetworkError,
)
from app.spiders.manager import spider_manager
import asyncio
router = APIRouter()
logger = logging.getLogger("kami_spider")
@router.get("/health", tags=["Health"])
def health_check():
"""健康检查接口"""
return {"status": "ok"}
# 爬虫管理接口
@router.get("/spiders", response_model=List[SpiderInfoResponse], tags=["Spiders"])
async def list_spiders():
@@ -24,109 +37,132 @@ async def list_spiders():
# 确保已发现所有爬虫
if not spider_manager.spider_classes:
spider_manager.discover_spiders()
spiders = []
for name, spider_class in spider_manager.spider_classes.items():
try:
doc = spider_class.__doc__ or ""
spiders.append({
"name": name,
"description": doc.strip(),
"parameters": {} # 这里可以添加更多参数信息
})
spiders.append(
{
"name": name,
"description": doc.strip(),
"parameters": {}, # 这里可以添加更多参数信息
}
)
except Exception as e:
logger.error(f"获取爬虫 {name} 信息失败: {e}")
return spiders
@router.post("/spider/task", response_model=SpiderTaskResponse, tags=["Spiders"])
async def submit_spider_task(task: SpiderTaskRequest):
"""提交爬虫任务"""
# 确保已发现所有爬虫
if not spider_manager.spider_classes:
spider_manager.discover_spiders()
# 检查爬虫是否存在
if task.spider_name not in spider_manager.spider_classes:
raise HTTPException(status_code=404, detail=f"爬虫 {task.spider_name} 不存在")
# 启动爬虫
try:
task_id = await spider_manager.start_spider(
task.spider_name,
**task.params
)
task_id = await spider_manager.start_spider(task.spider_name, **task.params)
if not task_id:
raise HTTPException(status_code=500, detail="创建爬虫任务失败")
return SpiderTaskResponse(
task_id=task_id,
status="submitted",
message=f"爬虫任务 {task_id} 已提交"
task_id=task_id, status="submitted", message=f"爬虫任务 {task_id} 已提交"
)
except Exception as e:
logger.error(f"提交爬虫任务失败: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"提交爬虫任务失败: {str(e)}")
@router.get("/spider/task/{task_id}", response_model=SpiderTaskStatusResponse, tags=["Spiders"])
@router.get(
"/spider/task/{task_id}", response_model=SpiderTaskStatusResponse, tags=["Spiders"]
)
async def get_spider_task_status(task_id: str):
"""获取爬虫任务状态"""
status = spider_manager.get_task_status(task_id)
if not status:
raise HTTPException(status_code=404, detail=f"任务 {task_id} 不存在")
return status
@router.post("/spider/task/{task_id}/pause", tags=["Spiders"])
async def pause_spider_task(task_id: str):
"""暂停爬虫任务"""
success = spider_manager.pause_task(task_id)
if not success:
raise HTTPException(status_code=400, detail=f"无法暂停任务 {task_id}")
return {"status": "success", "message": f"任务 {task_id} 已暂停"}
@router.post("/spider/task/{task_id}/resume", tags=["Spiders"])
async def resume_spider_task(task_id: str):
"""恢复爬虫任务"""
success = spider_manager.resume_task(task_id)
if not success:
raise HTTPException(status_code=400, detail=f"无法恢复任务 {task_id}")
return {"status": "success", "message": f"任务 {task_id} 已恢复"}
@router.post("/spider/task/{task_id}/stop", tags=["Spiders"])
async def stop_spider_task(task_id: str):
"""停止爬虫任务"""
success = spider_manager.stop_task(task_id)
if not success:
raise HTTPException(status_code=400, detail=f"无法停止任务 {task_id}")
return {"status": "success", "message": f"任务 {task_id} 已停止"}
@router.get("/spider/tasks", tags=["Spiders"])
async def list_spider_tasks():
"""获取所有爬虫任务"""
tasks = {}
for task_id, task in spider_manager.get_all_tasks().items():
tasks[task_id] = task.to_dict()
return tasks
@router.post("/spider/heepay/query", tags=["骏卡"], summary="查卡", response_model=ResponseData)
def get_heepay_search(data: HeePayQueryRequest) -> JSONResponse:
return tasks
@router.post(
"/spider/heepay/query", tags=["骏卡"], summary="查卡", response_model=ResponseData
)
def get_hee_pay_search(data: HeePayQueryRequest) -> JSONResponse:
try:
transaction = HeePaySpider({}).start(data.cardNo, data.cardSecret)
spider = HeePaySpider({})
transaction = spider.start(data.card_no, data.card_password)
response_data = ResponseData(
code=1,
msg="成功",
data=transaction,
)
return JSONResponse(content=jsonable_encoder(response_data, by_alias=False)) # 👈 关键在这里
return JSONResponse(content=jsonable_encoder(response_data, by_alias=False))
except CardNotFoundError as e:
logger.warning(f"骏卡号不存在: {str(e)}")
error_response = ResponseData(code=-1, msg=f"{str(e)}")
return JSONResponse(content=jsonable_encoder(error_response, by_alias=False))
except CaptchaError as e:
logger.warning(f"验证码识别失败: {str(e)}")
error_response = ResponseData(code=0, msg=f"验证码识别失败,请重试: {str(e)}")
return JSONResponse(content=jsonable_encoder(error_response, by_alias=False))
except NetworkError as e:
logger.error(f"网络请求异常: {str(e)}")
error_response = ResponseData(code=0, msg=f"网络请求异常: {str(e)}")
return JSONResponse(content=jsonable_encoder(error_response, by_alias=False))
except CardQueryException as e:
logger.error(f"骏卡查询异常: {str(e)}")
error_response = ResponseData(code=0, msg=f"查询异常: {str(e)}")
return JSONResponse(content=jsonable_encoder(error_response, by_alias=False))
except Exception as e:
print(traceback.format_exc())
return JSONResponse(content=ResponseData(
code=0,
msg=f"失败: {str(e)}"
))
logger.error(f"骏卡查询失败: {str(e)}", exc_info=True)
error_response = ResponseData(code=0, msg=f"未知错误: {str(e)}")
return JSONResponse(content=jsonable_encoder(error_response, by_alias=False))
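
For a quick end-to-end check of the reworked query endpoint, a minimal client call could look like the sketch below. It assumes the service runs locally on the default port 8000 with the router mounted without a prefix; the card number and password are placeholders:

import requests

# card_no / card_password are placeholder values for illustration only
payload = {"card_no": "1234567890123456", "card_password": "000000"}
resp = requests.post(
    "http://127.0.0.1:8000/spider/heepay/query", json=payload, timeout=60
)
body = resp.json()
# code is 1 on success, -1 when the card does not exist, 0 for captcha/network/other errors
print(body["code"], body["msg"])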

View File

@@ -1 +0,0 @@

app/core/config.py

@@ -6,6 +6,7 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
class LogConfig(BaseModel):
"""日志配置"""
console_enabled: bool = Field(True, description="是否启用控制台日志")
file_enabled: bool = Field(True, description="是否启用文件日志")
file_path: str = Field("logs/kami_spider.log", description="日志文件路径")
@@ -14,7 +15,7 @@ class LogConfig(BaseModel):
use_color: bool = Field(True, description="是否在控制台使用彩色日志")
encoding: str = Field("utf-8", description="日志文件编码")
# 日志轮转配置
max_bytes: int = Field(10*1024*1024, description="单个日志文件最大大小(字节)")
max_bytes: int = Field(10 * 1024 * 1024, description="单个日志文件最大大小(字节)")
backup_count: int = Field(10, description="保留的备份日志文件数量")
when: str = Field("midnight", description="时间轮转间隔: S/M/H/D/midnight")
interval: int = Field(1, description="轮转间隔数")
@@ -22,13 +23,17 @@ class LogConfig(BaseModel):
class OTelConfig(BaseModel):
"""OpenTelemetry 配置"""
enabled: bool = Field(False, description="是否启用 OpenTelemetry")
endpoint: str = Field("http://localhost:4318", description="OpenTelemetry 收集器 HTTP 终端")
endpoint: str = Field(
"http://localhost:4318", description="OpenTelemetry 收集器 HTTP 终端"
)
service_name: str = Field("kami-spider", description="服务名称")
class DatabaseConfig(BaseModel):
"""数据库配置"""
enabled: bool = Field(False, description="是否启用数据库")
host: str = Field("localhost", description="数据库主机")
port: int = Field(3306, description="数据库端口")
@@ -41,56 +46,64 @@ class DatabaseConfig(BaseModel):
class RedisConfig(BaseModel):
"""Redis 配置"""
enabled: bool = Field(False, description="是否启用 Redis")
host: str = Field("localhost", description="Redis 主机")
port: int = Field(6379, description="Redis 端口")
password: Optional[str] = Field(None, description="Redis 密码")
db: int = Field(0, description="Redis 数据库")
use_sentinel: bool = Field(False, description="是否使用 Redis Sentinel")
sentinel_hosts: List[str] = Field(default_factory=list, description="Sentinel 主机列表")
sentinel_hosts: List[str] = Field(
default_factory=list, description="Sentinel 主机列表"
)
sentinel_master: str = Field("mymaster", description="Sentinel 主节点名称")
class Settings(BaseSettings):
"""应用配置"""
# 应用基本信息
app_name: str = Field("Kami Spider", description="应用名称")
debug: bool = Field(False, description="是否为调试模式")
environment: str = Field("production", description="环境名称")
# 服务配置
host: str = Field("0.0.0.0", description="监听主机")
port: int = Field(8000, description="监听端口")
workers: int = Field(4, description="工作进程数")
# 模块配置
log: LogConfig = LogConfig()
otel: OTelConfig = OTelConfig()
database: DatabaseConfig = DatabaseConfig()
redis: RedisConfig = RedisConfig()
# 分布式配置
distributed_config_enabled: bool = Field(False, description="是否启用分布式配置")
distributed_config_type: Optional[str] = Field(None, description="分布式配置类型: etcd, consul, nacos")
distributed_config_endpoint: Optional[str] = Field(None, description="分布式配置终端")
distributed_config_type: Optional[str] = Field(
None, description="分布式配置类型: etcd, consul, nacos"
)
distributed_config_endpoint: Optional[str] = Field(
None, description="分布式配置终端"
)
# 可选项
proxy_pool_enabled: bool = Field(False, description="是否启用代理池")
model_config = SettingsConfigDict(
env_file='.env',
env_file_encoding='utf-8',
env_nested_delimiter='__',
env_file=".env",
env_file_encoding="utf-8",
env_nested_delimiter="__",
case_sensitive=False,
)
@field_validator('environment')
@field_validator("environment")
def validate_environment(cls, v: str) -> str:
allowed = ['development', 'testing', 'production']
allowed = ["development", "testing", "production"]
if v not in allowed:
raise ValueError(f"环境必须是以下之一: {', '.join(allowed)}")
return v
def get_settings_dict(self) -> Dict[str, Any]:
"""获取配置字典"""
return self.model_dump()
@@ -101,4 +114,4 @@ def get_settings() -> Settings:
return Settings()
settings = get_settings()
settings = get_settings()
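
For reference, the nested models above can be overridden from the environment or the .env file via the double-underscore delimiter set in model_config; a minimal sketch, assuming the app package is importable:

import os

# "__" separates the nested model name from the field name (matching is case-insensitive),
# so LOG__FILE_ENABLED maps to settings.log.file_enabled and REDIS__ENABLED to settings.redis.enabled
os.environ["LOG__FILE_ENABLED"] = "false"
os.environ["REDIS__ENABLED"] = "true"

from app.core.config import Settings

settings = Settings()
print(settings.log.file_enabled, settings.redis.enabled)  # False True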

app/core/logging_config.py

@@ -10,22 +10,23 @@ from app.core.config import settings
LOG_FILE = settings.log.file_path
ENCODING = settings.log.encoding
class JsonFormatter(logging.Formatter):
"""JSON格式化器解决编码问题"""
def format(self, record):
# 创建基础日志记录
log_record = {
'timestamp': datetime.utcnow().isoformat(),
'message': record.getMessage(),
'level': record.levelname,
'logger': record.name,
"timestamp": datetime.utcnow().isoformat(),
"message": record.getMessage(),
"level": record.levelname,
"logger": record.name,
}
# 添加异常信息
if record.exc_info:
log_record['exception'] = self.formatException(record.exc_info)
log_record["exception"] = self.formatException(record.exc_info)
# 转换为JSON字符串
return json.dumps(log_record, ensure_ascii=False)
@@ -35,52 +36,54 @@ def setup_logging():
logger = logging.getLogger("kami_spider")
logger.setLevel(getattr(logging, settings.log.level))
logger.handlers = [] # 清除所有已有的handler
# 确保日志目录存在
os.makedirs(os.path.dirname(LOG_FILE), exist_ok=True)
# 控制台日志
if settings.log.console_enabled:
console_handler = logging.StreamHandler(sys.stdout)
# 如果启用颜色
if settings.log.use_color:
color_formatter = colorlog.ColoredFormatter(
fmt='%(log_color)s%(levelname)s%(reset)s - %(message)s',
fmt="%(log_color)s%(levelname)s%(reset)s - %(message)s",
log_colors={
'DEBUG': 'cyan',
'INFO': 'green',
'WARNING': 'yellow',
'ERROR': 'red',
'CRITICAL': 'red,bg_white',
}
"DEBUG": "cyan",
"INFO": "green",
"WARNING": "yellow",
"ERROR": "red",
"CRITICAL": "red,bg_white",
},
)
console_handler.setFormatter(color_formatter)
else:
console_handler.setFormatter(JsonFormatter())
logger.addHandler(console_handler)
# 文件日志 - 大小轮转
if settings.log.file_enabled:
# 大小轮转处理器
size_handler = RotatingFileHandler(
LOG_FILE,
maxBytes=settings.log.max_bytes,
LOG_FILE,
maxBytes=settings.log.max_bytes,
backupCount=settings.log.backup_count,
encoding=ENCODING
encoding=ENCODING,
)
size_handler.setFormatter(JsonFormatter())
logger.addHandler(size_handler)
# 时间轮转处理器
time_log_file = f"{os.path.splitext(LOG_FILE)[0]}_daily{os.path.splitext(LOG_FILE)[1]}"
time_log_file = (
f"{os.path.splitext(LOG_FILE)[0]}_daily{os.path.splitext(LOG_FILE)[1]}"
)
time_handler = TimedRotatingFileHandler(
time_log_file,
when=settings.log.when,
interval=settings.log.interval,
backupCount=settings.log.backup_count,
encoding=ENCODING
encoding=ENCODING,
)
time_handler.setFormatter(JsonFormatter())
logger.addHandler(time_handler)
@@ -90,7 +93,9 @@ def setup_logging():
try:
from opentelemetry.sdk._logs import LoggingHandler
from opentelemetry.sdk._logs.export import BatchLogRecordProcessor
from opentelemetry.exporter.otlp.proto.http._log_exporter import OTLPLogExporter
from opentelemetry.exporter.otlp.proto.http._log_exporter import (
OTLPLogExporter,
)
from opentelemetry.sdk._logs import LoggerProvider
from opentelemetry import _logs
@@ -99,12 +104,14 @@ def setup_logging():
processor = BatchLogRecordProcessor(exporter)
provider.add_log_record_processor(processor)
_logs.set_logger_provider(provider)
otel_handler = LoggingHandler(level=getattr(logging, settings.log.level), logger_provider=provider)
otel_handler = LoggingHandler(
level=getattr(logging, settings.log.level), logger_provider=provider
)
logger.addHandler(otel_handler)
logger.info(f"OpenTelemetry 日志导出已启用,终端: {settings.otel.endpoint}")
except ImportError:
logger.warning("OTel 日志导出未启用,缺少相关依赖。")
else:
logger.info("OpenTelemetry 日志导出已禁用。")
return logger
return logger
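
Elsewhere in the app the configured logger is fetched by name; a minimal sketch of wiring it up at startup, assuming the module path shown in the imports above:

from app.core.logging_config import setup_logging

logger = setup_logging()  # attaches console/file/OTel handlers per settings and returns the "kami_spider" logger
logger.info("logging configured")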

View File

@@ -6,24 +6,25 @@ from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExport
import logging
from app.core.config import settings
def setup_tracer(app):
"""设置 OpenTelemetry 追踪,基于配置可选启用"""
logger = logging.getLogger("kami_spider")
if not settings.otel.enabled:
logger.info("OpenTelemetry 链路追踪已禁用。")
return
try:
provider = TracerProvider()
trace.set_tracer_provider(provider)
exporter = OTLPSpanExporter(endpoint=settings.otel.endpoint)
span_processor = BatchSpanProcessor(exporter)
provider.add_span_processor(span_processor)
FastAPIInstrumentor.instrument_app(app, tracer_provider=provider)
logger.info(f"OpenTelemetry 链路追踪已启用,终端: {settings.otel.endpoint}")
except Exception as e:
logger.warning(f"OTel 链路追踪初始化失败: {e}")
logger.warning(f"OTel 链路追踪初始化失败: {e}")

View File

@@ -1 +0,0 @@

app/db/mysql.py

@@ -6,14 +6,15 @@ from app.core.config import settings
logger = logging.getLogger("kami_spider")
pool = None
async def init_mysql():
"""初始化 MySQL 连接池"""
global pool
if not settings.database.enabled:
logger.info("MySQL 连接已禁用。")
return
try:
pool = await aiomysql.create_pool(
host=settings.database.host,
@@ -23,13 +24,16 @@ async def init_mysql():
db=settings.database.database,
minsize=settings.database.min_connections,
maxsize=settings.database.max_connections,
autocommit=True
autocommit=True,
)
logger.info(
f"MySQL 连接池已初始化: {settings.database.host}:{settings.database.port}/{settings.database.database}"
)
logger.info(f"MySQL 连接池已初始化: {settings.database.host}:{settings.database.port}/{settings.database.database}")
except Exception as e:
logger.error(f"MySQL 连接池初始化失败: {e}")
raise
async def close_mysql():
"""关闭 MySQL 连接池"""
global pool
@@ -38,37 +42,40 @@ async def close_mysql():
await pool.wait_closed()
logger.info("MySQL 连接池已关闭")
async def execute(query: str, *args, **kwargs) -> int:
"""执行 SQL 语句,返回影响的行数"""
if not pool:
logger.warning("MySQL 连接池未初始化")
return 0
async with pool.acquire() as conn:
async with conn.cursor() as cur:
await cur.execute(query, *args, **kwargs)
return cur.rowcount
async def fetch_one(query: str, *args, **kwargs) -> Optional[Dict[str, Any]]:
"""执行查询并获取单行结果"""
if not pool:
logger.warning("MySQL 连接池未初始化")
return None
async with pool.acquire() as conn:
async with conn.cursor(aiomysql.DictCursor) as cur:
await cur.execute(query, *args, **kwargs)
result = await cur.fetchone()
return result
async def fetch_all(query: str, *args, **kwargs) -> List[Dict[str, Any]]:
"""执行查询并获取所有结果"""
if not pool:
logger.warning("MySQL 连接池未初始化")
return []
async with pool.acquire() as conn:
async with conn.cursor(aiomysql.DictCursor) as cur:
await cur.execute(query, *args, **kwargs)
result = await cur.fetchall()
return result
return result

app/db/redis_client.py

@@ -7,34 +7,33 @@ logger = logging.getLogger("kami_spider")
redis_client: Optional[redis.Redis] = None
sentinel_client: Optional[redis.sentinel.Sentinel] = None
async def init_redis():
"""初始化 Redis 客户端"""
global redis_client, sentinel_client
if not settings.redis.enabled:
logger.info("Redis 连接已禁用。")
return
try:
if settings.redis.use_sentinel:
# 使用 Sentinel 进行高可用
if not settings.redis.sentinel_hosts:
logger.error("Redis Sentinel 配置错误: 未提供 sentinel_hosts")
return
sentinel_hosts = [(host.split(':')[0], int(host.split(':')[1]))
for host in settings.redis.sentinel_hosts]
sentinel_hosts = [
(host.split(":")[0], int(host.split(":")[1]))
for host in settings.redis.sentinel_hosts
]
sentinel_client = redis.sentinel.Sentinel(
sentinel_hosts,
password=settings.redis.password,
socket_timeout=1.0
sentinel_hosts, password=settings.redis.password, socket_timeout=1.0
)
redis_client = sentinel_client.master_for(
settings.redis.sentinel_master,
db=settings.redis.db,
socket_timeout=1.0
settings.redis.sentinel_master, db=settings.redis.db, socket_timeout=1.0
)
logger.info(f"Redis Sentinel 连接已初始化: {settings.redis.sentinel_hosts}")
else:
@@ -44,10 +43,12 @@ async def init_redis():
port=settings.redis.port,
password=settings.redis.password,
db=settings.redis.db,
decode_responses=True
decode_responses=True,
)
logger.info(f"Redis 连接已初始化: {settings.redis.host}:{settings.redis.port}")
logger.info(
f"Redis 连接已初始化: {settings.redis.host}:{settings.redis.port}"
)
# 测试连接
await redis_client.ping()
except Exception as e:
@@ -55,51 +56,57 @@ async def init_redis():
redis_client = None
sentinel_client = None
async def close_redis():
"""关闭 Redis 连接"""
global redis_client, sentinel_client
if redis_client:
await redis_client.close()
redis_client = None
logger.info("Redis 连接已关闭")
if sentinel_client:
await sentinel_client.close()
sentinel_client = None
logger.info("Redis Sentinel 连接已关闭")
async def get(key: str) -> Any:
"""获取键值"""
if not redis_client:
logger.warning("Redis 客户端未初始化")
return None
return await redis_client.get(key)
async def set(key: str, value: Union[str, bytes, int, float],
expire: Optional[int] = None) -> bool:
async def set(
key: str, value: Union[str, bytes, int, float], expire: Optional[int] = None
) -> bool:
"""设置键值"""
if not redis_client:
logger.warning("Redis 客户端未初始化")
return False
if expire:
return await redis_client.set(key, value, ex=expire)
return await redis_client.set(key, value)
async def delete(key: str) -> int:
"""删除键"""
if not redis_client:
logger.warning("Redis 客户端未初始化")
return 0
return await redis_client.delete(key)
async def exists(key: str) -> bool:
"""检查键是否存在"""
if not redis_client:
logger.warning("Redis 客户端未初始化")
return False
return await redis_client.exists(key) > 0
return await redis_client.exists(key) > 0
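
A short usage sketch for the module-level helpers above, assuming Redis is enabled in the settings and reachable (the key and value are illustrative):

import asyncio

from app.db import redis_client


async def demo():
    await redis_client.init_redis()
    # write with a one-hour TTL, then read the value back
    await redis_client.set("spider:last_run", "2025-06-02", expire=3600)
    print(await redis_client.get("spider:last_run"))
    await redis_client.close_redis()


asyncio.run(demo())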

app/main.py

@@ -11,6 +11,7 @@ import logging
# 导入可选数据库模块
try:
from app.db import mysql, redis_client
has_db_modules = True
except ImportError:
has_db_modules = False
@@ -20,9 +21,7 @@ from app.spiders.manager import spider_manager
from app.spiders.base.spider import SpiderStatus
app = FastAPI(
title=settings.app_name,
description="分布式爬虫 Web 框架",
debug=settings.debug
title=settings.app_name, description="分布式爬虫 Web 框架", debug=settings.debug
)
# 初始化日志
@@ -36,57 +35,58 @@ app.include_router(router)
shutdown_event = asyncio.Event()
@app.on_event("startup")
async def startup():
"""应用启动时执行"""
logger.info(f"{settings.app_name} 服务启动中,环境:{settings.environment}")
# 初始化数据库连接(如果启用)
if has_db_modules:
if settings.database.enabled:
await mysql.init_mysql()
if settings.redis.enabled:
await redis_client.init_redis()
# 发现并注册爬虫
spider_manager.discover_spiders()
logger.info(f"已发现 {len(spider_manager.spider_classes)} 个爬虫")
@app.on_event("shutdown")
async def on_shutdown():
"""应用关闭时执行"""
logger.info("应用正在优雅关闭,等待所有请求完成...")
shutdown_event.set()
# 停止所有运行中的爬虫
for task_id, spider in list(spider_manager.running_spiders.items()):
if spider.status in [SpiderStatus.RUNNING, SpiderStatus.PAUSED]:
logger.info(f"停止爬虫任务: {task_id}")
spider.stop()
# 关闭数据库连接
if has_db_modules:
if settings.database.enabled:
await mysql.close_mysql()
if settings.redis.enabled:
await redis_client.close_redis()
logger.info("应用已安全关闭。")
# 支持优雅关闭Gunicorn/Uvicorn 信号处理)
def handle_signal(*_):
logger.info("收到终止信号,触发优雅关闭...")
asyncio.create_task(app.router.shutdown())
signal.signal(signal.SIGTERM, handle_signal)
signal.signal(signal.SIGINT, handle_signal)
if __name__ == "__main__":
uvicorn.run(
"app.main:app",
host=settings.host,
port=settings.port,
reload=settings.debug
)
"app.main:app", host=settings.host, port=settings.port, reload=settings.debug
)

View File

@@ -1 +0,0 @@

app/models/heepay.py

@@ -1,45 +1,41 @@
from typing import Union, Optional
from typing import Union, Optional, Dict, Any
from pydantic import BaseModel, Field, model_validator
class CardUseRecord(BaseModel):
"""骏卡使用记录模型"""
transaction_type: str = Field(default=None, description="交易类型")
used_j_points: str = Field(default=None, description="使用 j 点")
usage_note: str = Field(default=None, description="使用说明")
usage_time: str = Field(default=None, description="使用时间")
result: str = Field(default=None, description="结果")
from pydantic import BaseModel, Field, validator
class CardTransaction(BaseModel):
# 原卡信息字段
cardNo: str = Field(default=None, description="卡号")
cardSecret: str = Field(default=None, description="卡密")
card_number: str = Field(default=None, alias='骏卡卡号', description="骏卡卡号")
card_type: str = Field(default=None, alias='骏卡卡种', description="骏卡卡种")
card_status: str = Field(default=None, alias='骏卡状态', description="骏卡状态")
j_points: str = Field(default=None, alias='J点面值', description="J点面值")
success_time: str = Field(default=None, alias='成功时间', description="成功时间")
"""骏卡交易模型"""
# 充值记录字段
product_name: str = Field(default=None, alias='充值产品名称', description="充值产品名称")
transaction_type: str = Field(default=None, alias='交易类型', description="交易类型")
j_points_used: str = Field(default=None, alias='使用 j 点', description="使用 j 点")
usage_note: str = Field(default=None, alias='使用说明', description="使用说明")
usage_time: str = Field(default=None, alias='使用时间', description="使用时间")
result: str = Field(default=None, alias='结果', description="结果")
# 卡片基本信息
card_number: str = Field(default=None, description="骏卡卡号")
card_password: str = Field(default=None, description="骏卡卡密")
card_type: str = Field(default=None, description="骏卡卡种")
card_status: str = Field(default=None, description="骏卡状态")
j_points: str = Field(default=None, description="J点面值")
locked_j_points: str = Field(default=None, description="锁定J点")
available_j_points: str = Field(default=None, description="可用J点")
success_time: str = Field(default=None, description="成功时间")
points: Union[str, float] = Field(default="异常", description="卡面值 J点面值 / 100失败则为“异常”")
use_records: list[CardUseRecord] = Field(default=[], description="使用记录")
class Config:
allow_population_by_field_name = True # 允许通过别名或字段名初始化
arbitrary_types_allowed = True # 允许任意类型
by_alias = False # 关键设置
@validator("points", always=True)
def convert_j_points(cls, v, values):
raw = values.get("j_points")
try:
num = int(raw)
return num / 100
except (ValueError, TypeError):
return "异常"
def map_alias_to_field(model: CardTransaction, data: dict) -> dict:
alias_map = {field.alias: name for name, field in model.__fields__.items()}
return {alias_map.get(k, k): v for k, v in data.items()}
class ResponseData(BaseModel):
"""API响应数据模型"""
code: int = Field(..., description="响应状态码1-成功0-失败")
msg: str = Field(..., description="响应信息: success-成功, 其他为错误信息")
data: Optional[CardTransaction] = Field(default=None, description="卡交易详情")

app/models/spider.py

@@ -1,26 +1,34 @@
from pydantic import BaseModel, Field
from typing import Optional, Dict, Any, List, Union
class SpiderTaskRequest(BaseModel):
"""爬虫任务请求"""
spider_name: str = Field(..., description="爬虫名称")
params: Dict[str, Any] = Field(default_factory=dict, description="爬虫参数")
callback: Optional[str] = Field(None, description="回调地址")
class SpiderTaskResponse(BaseModel):
"""爬虫任务响应"""
task_id: str = Field(..., description="任务 ID")
status: str = Field(..., description="任务状态")
message: Optional[str] = Field(None, description="消息")
class SpiderInfoResponse(BaseModel):
"""爬虫信息响应"""
name: str = Field(..., description="爬虫名称")
description: str = Field("", description="爬虫描述")
parameters: Dict[str, Any] = Field(default_factory=dict, description="参数信息")
class SpiderStatistics(BaseModel):
"""爬虫统计信息"""
start_time: Optional[str] = Field(None, description="开始时间")
end_time: Optional[str] = Field(None, description="结束时间")
duration: Optional[float] = Field(None, description="持续时间(秒)")
@@ -31,8 +39,10 @@ class SpiderStatistics(BaseModel):
item_count: int = Field(0, description="处理数据项数")
error_messages: List[str] = Field(default_factory=list, description="错误消息")
class SpiderTaskStatusResponse(BaseModel):
"""爬虫任务状态响应"""
task_id: str = Field(..., description="任务 ID")
spider_name: str = Field(..., description="爬虫名称")
status: str = Field(..., description="任务状态")
@@ -43,7 +53,7 @@ class SpiderTaskStatusResponse(BaseModel):
spider_status: Optional[str] = Field(None, description="爬虫状态")
stats: Optional[Dict[str, Any]] = Field(None, description="爬虫统计")
error: Optional[str] = Field(None, description="错误信息")
params: Dict[str, Any] = Field(default_factory=dict, description="爬虫参数")
params: Dict[str, Any] = Field(default_factory=dict, description="爬虫参数")
class ResponseData(BaseModel):
@@ -53,5 +63,5 @@ class ResponseData(BaseModel):
class HeePayQueryRequest(BaseModel):
cardNo: str = Field(..., description="卡号")
cardSecret: str = Field(..., description="卡密")
card_no: str = Field(..., description="卡号")
card_password: str = Field(..., description="卡密")

app/spiders/__init__.py

@@ -1,4 +1,4 @@
"""
爬虫模块包,包含各种爬虫实现。
每个爬虫都是独立的,可以单独运行,也可以通过调度器统一管理。
"""
"""

app/spiders/base/__init__.py

@@ -1,4 +1,4 @@
"""
爬虫基类包,定义爬虫的基础接口和抽象类。
这些基类将被具体的爬虫实现所继承。
"""
"""

View File

@@ -1,6 +1,7 @@
"""
基于HTTP的爬虫实现提供HTTP请求和响应处理。
"""
import aiohttp
import logging
import asyncio
@@ -10,17 +11,27 @@ from app.core.config import settings
import json
from yarl import URL
T = TypeVar('T')
T = TypeVar("T")
logger = logging.getLogger("kami_spider")
class HttpSpiderRequest(SpiderRequest):
"""HTTP爬虫请求"""
def __init__(self, url: str, method: str = "GET", headers: Optional[Dict[str, str]] = None,
params: Optional[Dict[str, Any]] = None, data: Optional[Any] = None,
json_data: Optional[Dict[str, Any]] = None, cookies: Optional[Dict[str, str]] = None,
timeout: int = 30, retry: int = 3, metadata: Optional[Dict[str, Any]] = None,
proxy: Optional[str] = None):
def __init__(
self,
url: str,
method: str = "GET",
headers: Optional[Dict[str, str]] = None,
params: Optional[Dict[str, Any]] = None,
data: Optional[Any] = None,
json_data: Optional[Dict[str, Any]] = None,
cookies: Optional[Dict[str, str]] = None,
timeout: int = 30,
retry: int = 3,
metadata: Optional[Dict[str, Any]] = None,
proxy: Optional[str] = None,
):
super().__init__(url, method, headers, params, data, timeout, retry, metadata)
self.json_data = json_data
self.cookies = cookies or {}
@@ -29,14 +40,21 @@ class HttpSpiderRequest(SpiderRequest):
class HttpSpiderResponse(SpiderResponse):
"""HTTP爬虫响应"""
def __init__(self, request: HttpSpiderRequest, status_code: int = 200,
headers: Optional[Dict[str, str]] = None, body: Any = None,
text: Optional[str] = None, cookies: Optional[Dict[str, str]] = None,
error: Optional[Exception] = None):
def __init__(
self,
request: HttpSpiderRequest,
status_code: int = 200,
headers: Optional[Dict[str, str]] = None,
body: Any = None,
text: Optional[str] = None,
cookies: Optional[Dict[str, str]] = None,
error: Optional[Exception] = None,
):
super().__init__(request, status_code, headers, body, error)
self.text = text
self.cookies = cookies or {}
def json(self) -> Any:
"""解析JSON响应"""
if not self.text:
@@ -50,13 +68,19 @@ class HttpSpiderResponse(SpiderResponse):
class HttpSpider(Spider[HttpSpiderRequest, HttpSpiderResponse, T], Generic[T]):
"""HTTP爬虫基类"""
def __init__(self, name: str, concurrent_requests: int = 5,
user_agent: Optional[str] = None, cookies: Optional[Dict[str, str]] = None,
proxy: Optional[str] = None, use_proxy_pool: bool = False):
def __init__(
self,
name: str,
concurrent_requests: int = 5,
user_agent: Optional[str] = None,
cookies: Optional[Dict[str, str]] = None,
proxy: Optional[str] = None,
use_proxy_pool: bool = False,
):
"""
初始化HTTP爬虫
Args:
name: 爬虫名称
concurrent_requests: 并发请求数
@@ -71,20 +95,20 @@ class HttpSpider(Spider[HttpSpiderRequest, HttpSpiderResponse, T], Generic[T]):
self.proxy = proxy
self.use_proxy_pool = use_proxy_pool and settings.proxy_pool_enabled
self.session: Optional[aiohttp.ClientSession] = None
async def create_session(self) -> aiohttp.ClientSession:
"""创建HTTP会话"""
if self.session is None or self.session.closed:
headers = {"User-Agent": self.user_agent}
self.session = aiohttp.ClientSession(headers=headers, cookies=self.cookies)
return self.session
async def close_session(self) -> None:
"""关闭HTTP会话"""
if self.session and not self.session.closed:
await self.session.close()
logger.debug(f"爬虫 {self.name} 的HTTP会话已关闭")
async def get_proxy(self) -> Optional[str]:
"""获取代理地址"""
if self.use_proxy_pool:
@@ -92,7 +116,7 @@ class HttpSpider(Spider[HttpSpiderRequest, HttpSpiderResponse, T], Generic[T]):
# TODO: 实现代理池获取逻辑
return None
return self.proxy
async def make_request(self, url: str, **kwargs) -> HttpSpiderRequest:
"""创建HTTP请求"""
method = kwargs.get("method", "GET")
@@ -105,7 +129,7 @@ class HttpSpider(Spider[HttpSpiderRequest, HttpSpiderResponse, T], Generic[T]):
retry = kwargs.get("retry", 3)
metadata = kwargs.get("metadata", {})
proxy = kwargs.get("proxy", await self.get_proxy())
return HttpSpiderRequest(
url=url,
method=method,
@@ -117,14 +141,14 @@ class HttpSpider(Spider[HttpSpiderRequest, HttpSpiderResponse, T], Generic[T]):
timeout=timeout,
retry=retry,
metadata=metadata,
proxy=proxy
proxy=proxy,
)
async def fetch(self, request: HttpSpiderRequest) -> HttpSpiderResponse:
"""执行HTTP请求并获取响应"""
session = await self.create_session()
retry_count = 0
while retry_count <= request.retry:
try:
async with session.request(
@@ -136,50 +160,50 @@ class HttpSpider(Spider[HttpSpiderRequest, HttpSpiderResponse, T], Generic[T]):
headers=request.headers,
cookies=request.cookies,
proxy=request.proxy,
timeout=aiohttp.ClientTimeout(total=request.timeout)
timeout=aiohttp.ClientTimeout(total=request.timeout),
) as resp:
text = await resp.text()
return HttpSpiderResponse(
request=request,
status_code=resp.status,
headers=dict(resp.headers),
body=await resp.read(),
text=text,
cookies=dict(resp.cookies)
cookies=dict(resp.cookies),
)
except (aiohttp.ClientError, asyncio.TimeoutError) as e:
retry_count += 1
self.stats.add_retry()
if retry_count <= request.retry:
wait_time = 2 ** retry_count # 指数退避
logger.warning(f"请求失败,将在 {wait_time} 秒后重试 ({retry_count}/{request.retry}): {request.url}")
wait_time = 2**retry_count # 指数退避
logger.warning(
f"请求失败,将在 {wait_time} 秒后重试 ({retry_count}/{request.retry}): {request.url}"
)
await asyncio.sleep(wait_time)
else:
logger.error(f"请求失败,已达到最大重试次数 ({request.retry}): {request.url}")
return HttpSpiderResponse(
request=request,
status_code=0,
error=e
logger.error(
f"请求失败,已达到最大重试次数 ({request.retry}): {request.url}"
)
return HttpSpiderResponse(request=request, status_code=0, error=e)
async def start(self) -> None:
"""启动爬虫"""
try:
# 创建HTTP会话
await self.create_session()
# 启动爬虫
await super().start()
finally:
# 关闭会话
await self.close_session()
@classmethod
async def create_and_run(cls, *args, **kwargs) -> "HttpSpider":
"""创建爬虫实例并运行"""
spider = cls(*args, **kwargs)
await spider.start()
return spider
return spider

app/spiders/base/spider.py

@@ -1,6 +1,7 @@
"""
爬虫基类模块,定义爬虫的基本接口和行为。
"""
import abc
import logging
import asyncio
@@ -9,27 +10,37 @@ from datetime import datetime
from app.core.config import settings
# 定义类型变量用于泛型
T = TypeVar('T')
RequestT = TypeVar('RequestT')
ResponseT = TypeVar('ResponseT')
T = TypeVar("T")
RequestT = TypeVar("RequestT")
ResponseT = TypeVar("ResponseT")
logger = logging.getLogger("kami_spider")
class SpiderStatus:
"""爬虫状态枚举"""
IDLE = "idle" # 空闲
RUNNING = "running" # 运行中
PAUSED = "paused" # 已暂停
IDLE = "idle" # 空闲
RUNNING = "running" # 运行中
PAUSED = "paused" # 已暂停
FINISHED = "finished" # 已完成
ERROR = "error" # 出错
ERROR = "error" # 出错
class SpiderRequest:
"""爬虫请求基类"""
def __init__(self, url: str, method: str = "GET", headers: Optional[Dict[str, str]] = None,
params: Optional[Dict[str, Any]] = None, data: Optional[Any] = None,
timeout: int = 30, retry: int = 3, metadata: Optional[Dict[str, Any]] = None):
def __init__(
self,
url: str,
method: str = "GET",
headers: Optional[Dict[str, str]] = None,
params: Optional[Dict[str, Any]] = None,
data: Optional[Any] = None,
timeout: int = 30,
retry: int = 3,
metadata: Optional[Dict[str, Any]] = None,
):
self.url = url
self.method = method
self.headers = headers or {}
@@ -43,9 +54,15 @@ class SpiderRequest:
class SpiderResponse:
"""爬虫响应基类"""
def __init__(self, request: SpiderRequest, status_code: int = 200,
headers: Optional[Dict[str, str]] = None, body: Any = None,
error: Optional[Exception] = None):
def __init__(
self,
request: SpiderRequest,
status_code: int = 200,
headers: Optional[Dict[str, str]] = None,
body: Any = None,
error: Optional[Exception] = None,
):
self.request = request
self.status_code = status_code
self.headers = headers or {}
@@ -56,7 +73,10 @@ class SpiderResponse:
class SpiderItem(Generic[T]):
"""爬虫数据项基类"""
def __init__(self, data: T, source_url: str, metadata: Optional[Dict[str, Any]] = None):
def __init__(
self, data: T, source_url: str, metadata: Optional[Dict[str, Any]] = None
):
self.data = data
self.source_url = source_url
self.metadata = metadata or {}
@@ -65,6 +85,7 @@ class SpiderItem(Generic[T]):
class SpiderStats:
"""爬虫统计数据"""
def __init__(self):
self.start_time: Optional[datetime] = None
self.end_time: Optional[datetime] = None
@@ -74,44 +95,44 @@ class SpiderStats:
self.retry_count: int = 0
self.item_count: int = 0
self.error_messages: List[str] = []
def start(self):
"""开始记录统计"""
self.start_time = datetime.now()
def finish(self):
"""结束记录统计"""
self.end_time = datetime.now()
def add_request(self):
"""增加请求计数"""
self.request_count += 1
def add_success(self):
"""增加成功计数"""
self.success_count += 1
def add_failure(self, error_msg: str = None):
"""增加失败计数"""
self.failure_count += 1
if error_msg:
self.error_messages.append(error_msg)
def add_retry(self):
"""增加重试计数"""
self.retry_count += 1
def add_item(self):
"""增加数据项计数"""
self.item_count += 1
def get_duration(self) -> Optional[float]:
"""获取持续时间(秒)"""
if self.start_time is None:
return None
end = self.end_time or datetime.now()
return (end - self.start_time).total_seconds()
def get_stats_dict(self) -> Dict[str, Any]:
"""获取统计数据字典"""
return {
@@ -123,17 +144,17 @@ class SpiderStats:
"failure_count": self.failure_count,
"retry_count": self.retry_count,
"item_count": self.item_count,
"error_messages": self.error_messages
"error_messages": self.error_messages,
}
class Spider(abc.ABC, Generic[RequestT, ResponseT, T]):
"""爬虫抽象基类"""
def __init__(self, name: str, concurrent_requests: int = 5):
"""
初始化爬虫
Args:
name: 爬虫名称
concurrent_requests: 并发请求数
@@ -146,102 +167,102 @@ class Spider(abc.ABC, Generic[RequestT, ResponseT, T]):
self._seen_urls: Set[str] = set()
self.semaphore = asyncio.Semaphore(concurrent_requests)
logger.info(f"爬虫 {self.name} 已初始化,并发请求数: {concurrent_requests}")
@abc.abstractmethod
async def parse(self, response: ResponseT) -> Union[List[RequestT], List[T], None]:
"""
解析响应数据,并返回新的请求或数据项
Args:
response: 响应对象
Returns:
新的请求列表或数据项列表
"""
pass
@abc.abstractmethod
async def process_item(self, item: T) -> Optional[T]:
"""
处理数据项
Args:
item: 数据项
Returns:
处理后的数据项如果返回None表示丢弃该项
"""
pass
@abc.abstractmethod
async def make_request(self, url: str, **kwargs) -> RequestT:
"""
创建请求对象
Args:
url: 请求URL
**kwargs: 其他参数
Returns:
请求对象
"""
pass
@abc.abstractmethod
async def fetch(self, request: RequestT) -> ResponseT:
"""
执行请求并获取响应
Args:
request: 请求对象
Returns:
响应对象
"""
pass
async def start(self) -> None:
"""启动爬虫"""
if not self.start_urls:
logger.warning(f"爬虫 {self.name} 没有起始URL无法启动")
return
logger.info(f"爬虫 {self.name} 开始运行")
self.status = SpiderStatus.RUNNING
self.stats.start()
# 创建初始请求
requests = [await self.make_request(url) for url in self.start_urls]
# 启动请求处理
tasks = [self.process_request(request) for request in requests]
await asyncio.gather(*tasks)
self.status = SpiderStatus.FINISHED
self.stats.finish()
logger.info(f"爬虫 {self.name} 已完成,统计: {self.stats.get_stats_dict()}")
async def process_request(self, request: RequestT) -> None:
"""处理单个请求"""
if isinstance(request, SpiderRequest) and request.url in self._seen_urls:
logger.debug(f"URL已处理过跳过: {request.url}")
return
async with self.semaphore:
if self.status != SpiderStatus.RUNNING:
return
self.stats.add_request()
if isinstance(request, SpiderRequest):
self._seen_urls.add(request.url)
try:
response = await self.fetch(request)
self.stats.add_success()
result = await self.parse(response)
if result:
for item in result:
if isinstance(item, SpiderRequest):
@@ -252,34 +273,34 @@ class Spider(abc.ABC, Generic[RequestT, ResponseT, T]):
processed_item = await self.process_item(item)
if processed_item is not None:
self.stats.add_item()
except Exception as e:
self.stats.add_failure(str(e))
logger.error(f"处理请求失败: {e}", exc_info=True)
def pause(self) -> None:
"""暂停爬虫"""
if self.status == SpiderStatus.RUNNING:
self.status = SpiderStatus.PAUSED
logger.info(f"爬虫 {self.name} 已暂停")
def resume(self) -> None:
"""恢复爬虫"""
if self.status == SpiderStatus.PAUSED:
self.status = SpiderStatus.RUNNING
logger.info(f"爬虫 {self.name} 已恢复")
def stop(self) -> None:
"""停止爬虫"""
self.status = SpiderStatus.FINISHED
self.stats.finish()
logger.info(f"爬虫 {self.name} 已停止")
def get_stats(self) -> Dict[str, Any]:
"""获取爬虫统计数据"""
return {
"name": self.name,
"status": self.status,
"concurrent_requests": self.concurrent_requests,
"stats": self.stats.get_stats_dict()
}
"stats": self.stats.get_stats_dict(),
}

View File

@@ -1 +1 @@
"""爬虫示例包,包含各种爬虫实现示例。"""
"""爬虫示例包,包含各种爬虫实现示例。"""

View File

@@ -1,6 +1,7 @@
"""
简单爬虫示例,展示如何实现一个基本的爬虫。
"""
from typing import List, Optional, Union, Dict, Any
import logging
from bs4 import BeautifulSoup
@@ -11,13 +12,18 @@ logger = logging.getLogger("kami_spider")
class SimpleSpider(HttpSpider[Dict[str, Any]]):
"""简单爬虫示例,抓取网页标题和链接"""
def __init__(self, name: str = "SimpleSpider", concurrent_requests: int = 3,
max_depth: int = 2, allowed_domains: Optional[List[str]] = None,
**kwargs):
def __init__(
self,
name: str = "SimpleSpider",
concurrent_requests: int = 3,
max_depth: int = 2,
allowed_domains: Optional[List[str]] = None,
**kwargs,
):
"""
初始化简单爬虫
Args:
name: 爬虫名称
concurrent_requests: 并发请求数
@@ -27,67 +33,68 @@ class SimpleSpider(HttpSpider[Dict[str, Any]]):
super().__init__(name, concurrent_requests, **kwargs)
self.max_depth = max_depth
self.allowed_domains = allowed_domains or []
# 设置起始URL
self.start_urls = [
"https://example.com/",
"https://www.python.org/"
]
self.start_urls = ["https://example.com/", "https://www.python.org/"]
def is_allowed_domain(self, url: str) -> bool:
"""检查URL是否属于允许的域名"""
if not self.allowed_domains:
return True
try:
from urllib.parse import urlparse
domain = urlparse(url).netloc
return any(domain.endswith(d) for d in self.allowed_domains)
except Exception as e:
logger.error(f"解析URL时出错: {e}")
return False
async def parse(self, response: HttpSpiderResponse) -> Union[List[HttpSpiderRequest], List[Dict[str, Any]], None]:
async def parse(
self, response: HttpSpiderResponse
) -> Union[List[HttpSpiderRequest], List[Dict[str, Any]], None]:
"""解析响应"""
if response.error or response.status_code != 200:
logger.error(f"请求失败: {response.error or response.status_code}")
return None
# 获取当前深度
current_depth = response.request.metadata.get("depth", 1)
logger.info(f"正在解析: {response.request.url} (深度: {current_depth})")
# 解析HTML
try:
soup = BeautifulSoup(response.text, "html.parser")
# 提取标题
title = soup.title.string.strip() if soup.title else "无标题"
# 创建数据项
item = {
"url": response.request.url,
"title": title,
"depth": current_depth,
"status": response.status_code
"status": response.status_code,
}
results = [item]
# 如果未达到最大深度,则提取链接并创建新的请求
if current_depth < self.max_depth:
new_requests = []
# 提取所有链接
links = soup.find_all("a", href=True)
for link in links[:5]: # 限制每页最多处理5个链接
href = link["href"]
# 将相对URL转为绝对URL
if href.startswith("/"):
from urllib.parse import urljoin
href = urljoin(response.request.url, href)
# 只处理http/https链接
if href.startswith(("http://", "https://")):
# 检查域名是否允许
@@ -95,20 +102,23 @@ class SimpleSpider(HttpSpider[Dict[str, Any]]):
# 创建新请求
req = await self.make_request(
href,
metadata={"depth": current_depth + 1, "parent": response.request.url}
metadata={
"depth": current_depth + 1,
"parent": response.request.url,
},
)
new_requests.append(req)
results.extend(new_requests)
return results
except Exception as e:
logger.error(f"解析HTML时出错: {e}", exc_info=True)
return None
async def process_item(self, item: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""处理数据项"""
# 简单地记录数据项
logger.info(f"处理数据项: {item['url']} - {item['title']}")
return item
return item
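
Outside the FastAPI service the example spider can also be run directly; a hedged sketch, where the import path is an assumption about where this module lives:

import asyncio

from app.spiders.examples.simple_spider import SimpleSpider  # import path assumed


async def main():
    spider = SimpleSpider(max_depth=1, allowed_domains=["example.com"])
    await spider.start()  # crawls the built-in start_urls
    print(spider.get_stats())


asyncio.run(main())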

app/spiders/heepay.py

@@ -1,172 +1,470 @@
"""
骏卡查询爬虫模块,负责查询骏卡卡号和卡密的相关信息
设计模式:
1. 单例模式 - OCR识别器
2. 策略模式 - 验证码处理
3. 模板方法模式 - 爬虫流程
"""
import base64
import hashlib
import json
import re
from typing import Dict
from enum import Enum
from typing import Dict, Optional, Tuple, Union, List
import ddddocr
import requests
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad
from lxml import html
from lxml.html import HtmlElement
import ddddocr
from app.core.logging_config import setup_logging
from app.models.heepay import CardTransaction, map_alias_to_field
from bs4 import BeautifulSoup, NavigableString
from retry import retry
from app.core.logging_config import setup_logging
from app.models.heepay import CardTransaction, CardUseRecord
# 设置日志记录器
logger = setup_logging()
class HeePaySpider:
__ocr = ddddocr.DdddOcr()
def __init__(self, proxies=None):
class CardQueryException(Exception):
"""骏卡查询异常基类"""
pass
class CaptchaError(CardQueryException):
"""验证码错误异常"""
pass
class CardNotFoundError(CardQueryException):
"""卡号不存在异常"""
pass
class NetworkError(CardQueryException):
"""网络异常"""
pass
class CardStatus(Enum):
"""骏卡状态枚举"""
NORMAL = "正常"
USED = "已使用"
class OcrRecognizer:
"""OCR识别器单例类"""
_instance = None
def __init__(self):
self._ocr = None
def __new__(cls):
if cls._instance is None:
cls._instance = super(OcrRecognizer, cls).__new__(cls)
cls._instance._ocr = ddddocr.DdddOcr(show_ad=False)
return cls._instance
def recognize(self, image_bytes: bytes) -> str:
"""识别图片中的文字"""
if not hasattr(self, "_ocr") or self._ocr is None:
self._ocr = ddddocr.DdddOcr(show_ad=False)
result = self._ocr.classification(image_bytes)
return str(result)
class HeePaySpider:
"""骏卡查询爬虫"""
# 基础URL配置
BASE_URL = "https://junka.heepay.com"
INDEX_URL = f"{BASE_URL}/Bill/TradeSearch.aspx"
CAPTCHA_URL = f"{BASE_URL}/Modules/RandomFzImage.aspx"
# AES加密配置
AES_KEY = "eda0ac8d3b5e9446"
AES_IV = "0123456789ABCDEF"
SIGN_PREFIX = "fK8kpKe7YcutStELHL1JuJw=="
# XPath选择器
CARD_RESULT_SELECTOR = (
"//div[@id='ctl00_MainContent_divUCardSearchResult']//tbody//tr"
)
TRADE_RESULT_SELECTOR = "//div[@id='ctl00_MainContent_panUCardTrade']//tbody//tr"
def __init__(self, proxies: Optional[Dict[str, str]] = None):
"""
初始化骏卡查询爬虫
Args:
proxies: 代理设置,格式为 {"http": "http://proxy.example.com", "https": "https://proxy.example.com"}
"""
self.session = requests.Session()
self.proxies = proxies
self.ocr = OcrRecognizer()
def get_index(self):
headers = {
# 移动端请求头
self.mobile_headers = {
"Host": "junka.heepay.com",
"Pragma": "no-cache",
"Cache-Control": "no-cache",
"sec-ch-ua": "\"\"",
"sec-ch-ua": '""',
"sec-ch-ua-mobile": "?1",
"sec-ch-ua-platform": "\"\"",
"sec-ch-ua-platform": '""',
"Upgrade-Insecure-Requests": "1",
"User-Agent": "Mozilla/5.0 (Linux; Android 5.0; SM-N9100 Build/LRX21V) > AppleWebKit/537.36 (KHTML, like Gecko) Version/4.0 > Chrome/37.0.0.0 Mobile Safari/537.36 > MicroMessenger/6.0.2.56_r958800.520 NetType/WIFI",
"User-Agent": "Mozilla/5.0 (Linux; Android 5.0; SM-N9100 Build/LRX21V) AppleWebKit/537.36 (KHTML, like Gecko) Version/4.0 Chrome/37.0.0.0 Mobile Safari/537.36 MicroMessenger/6.0.2.56_r958800.520 NetType/WIFI",
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7",
"Sec-Fetch-Site": "none",
"Sec-Fetch-Mode": "navigate",
"Sec-Fetch-User": "?1",
"Sec-Fetch-Dest": "document",
"Accept-Language": "zh-CN,zh;q=0.9"
"Accept-Language": "zh-CN,zh;q=0.9",
}
url = "https://junka.heepay.com/Bill/TradeSearch.aspx"
response = self.session.get(url, headers=headers, proxies=self.proxies)
viewstate = re.findall('id="__VIEWSTATE" value="(.*?)"', response.text)[0]
generator = re.findall('id="__VIEWSTATEGENERATOR" value="(.*?)"', response.text)[0]
AffineX = re.findall('id="ctl00_MainContent_hidAffineX" value="(.*?)"', response.text)[0]
AffineY = re.findall('id="ctl00_MainContent_hidAffineY" value="(.*?)"', response.text)[0]
return viewstate, generator, AffineX, AffineY
def get_imgcode(self):
headers = {
# 图片请求头
self.image_headers = {
"Accept": "image/avif,image/webp,image/apng,image/svg+xml,image/*,*/*;q=0.8",
"Accept-Language": "zh-CN,zh;q=0.9",
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Pragma": "no-cache",
"Referer": "https://junka.heepay.com/Bill/TradeSearch.aspx",
"Referer": self.INDEX_URL,
"Sec-Fetch-Dest": "image",
"Sec-Fetch-Mode": "no-cors",
"Sec-Fetch-Site": "same-origin",
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/133.0.0.0 Safari/537.36",
"sec-ch-ua": "\"Not(A:Brand\";v=\"99\", \"Google Chrome\";v=\"133\", \"Chromium\";v=\"133\"",
"sec-ch-ua": '"Not(A:Brand";v="99", "Google Chrome";v="133", "Chromium";v="133"',
"sec-ch-ua-mobile": "?0",
"sec-ch-ua-platform": "\"macOS\""
"sec-ch-ua-platform": '"macOS"',
}
url = "https://junka.heepay.com/Modules/RandomFzImage.aspx"
response = self.session.get(url, headers=headers, proxies=self.proxies)
code = self.__ocr.classification(response.content)
return code
def queryCard(self, cardNo: str, cardSecret: str, viewstate: str, generator: str, AffineX: str, AffineY: str,
code: str) -> CardTransaction:
headers = {
"Host": "junka.heepay.com",
"Pragma": "no-cache",
"Cache-Control": "no-cache",
"sec-ch-ua": "\"\"",
"sec-ch-ua-mobile": "?1",
"sec-ch-ua-platform": "\"\"",
"Origin": "https://junka.heepay.com",
"Upgrade-Insecure-Requests": "1",
"User-Agent": "Mozilla/5.0 (Linux; Android 5.0; SM-N9100 Build/LRX21V) > AppleWebKit/537.36 (KHTML, like Gecko) Version/4.0 > Chrome/37.0.0.0 Mobile Safari/537.36 > MicroMessenger/6.0.2.56_r958800.520 NetType/WIFI",
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7",
"Sec-Fetch-Site": "same-origin",
"Sec-Fetch-Mode": "navigate",
"Sec-Fetch-User": "?1",
"Sec-Fetch-Dest": "document",
"Referer": "https://junka.heepay.com/Bill/TradeSearch.aspx",
"Accept-Language": "zh-CN,zh;q=0.9"
}
url = "https://junka.heepay.com/Bill/TradeSearch.aspx"
def get_index(self) -> Tuple[str, str, str, str]:
"""
获取首页并提取必要参数
Returns:
Tuple[str, str, str, str]: (viewstate, generator, affine_x, affine_y)
Raises:
NetworkError: 网络请求异常
"""
try:
response = self.session.get(
self.INDEX_URL,
headers=self.mobile_headers,
proxies=self.proxies,
timeout=30,
)
response.raise_for_status()
# 提取表单参数
viewstate = self._extract_pattern(
response.text, 'id="__VIEWSTATE" value="(.*?)"'
)
generator = self._extract_pattern(
response.text, 'id="__VIEWSTATEGENERATOR" value="(.*?)"'
)
affine_x = self._extract_pattern(
response.text, 'id="ctl00_MainContent_hidAffineX" value="(.*?)"'
)
affine_y = self._extract_pattern(
response.text, 'id="ctl00_MainContent_hidAffineY" value="(.*?)"'
)
logger.debug("成功获取首页参数")
return viewstate, generator, affine_x, affine_y
except requests.RequestException as e:
logger.error(f"获取首页失败: {str(e)}")
raise NetworkError(f"网络请求异常: {str(e)}")
except (IndexError, ValueError) as e:
logger.error(f"解析首页参数失败: {str(e)}")
raise CardQueryException(f"解析首页参数失败: {str(e)}")
def get_captcha(self) -> str:
"""
获取并识别验证码
Returns:
str: 识别的验证码文本
Raises:
NetworkError: 网络请求异常
CaptchaError: 验证码识别失败
"""
try:
response = self.session.get(
self.CAPTCHA_URL,
headers=self.image_headers,
proxies=self.proxies,
timeout=30,
)
response.raise_for_status()
# 使用OCR识别验证码
code = self.ocr.recognize(response.content)
if not code or len(code) < 4:
logger.warning(f"验证码识别可能不准确: {code}")
else:
logger.debug(f"验证码识别成功: {code}")
return code
except requests.RequestException as e:
logger.error(f"获取验证码失败: {str(e)}")
raise NetworkError(f"网络请求异常: {str(e)}")
except Exception as e:
logger.error(f"验证码识别失败: {str(e)}")
raise CaptchaError(f"验证码识别失败: {str(e)}")
def query_card(
self,
card_no: str,
card_password: str,
viewstate: str,
generator: str,
affine_x: str,
affine_y: str,
code: str,
) -> CardTransaction:
"""
查询骏卡信息
Args:
card_no: 骏卡卡号
card_password: 骏卡卡密
viewstate: 页面状态参数
generator: 页面状态生成器参数
affine_x: 页面坐标X参数
affine_y: 页面坐标Y参数
code: 验证码
Returns:
CardTransaction: 卡片交易信息对象
Raises:
CaptchaError: 验证码错误
CardNotFoundError: 卡号不存在
CardSecretError: 卡密错误
NetworkError: 网络请求异常
CardQueryException: 其他查询异常
"""
# 构建请求数据
data = {
"__VIEWSTATE": viewstate,
"__VIEWSTATEGENERATOR": generator,
"__VIEWSTATEENCRYPTED": "",
"ctl00$Header$txtSearchKey": "",
"ctl00$MainContent$hidAffineX": AffineX,
"ctl00$MainContent$hidAffineY": AffineY,
"ctl00$MainContent$txtCardNo": cardNo,
"ctl00$MainContent$hidAffineX": affine_x,
"ctl00$MainContent$hidAffineY": affine_y,
"ctl00$MainContent$txtCardNo": card_no,
"ctl00$MainContent$hidUCardPassword": "",
"ctl00$MainContent$txtUCardPassword": cardSecret,
"ctl00$MainContent$txtUCardPassword": card_password,
"ctl00$MainContent$txtCardNoExtCode": code, # 验证码
"ctl00$MainContent$btnSearch": ""
"ctl00$MainContent$btnSearch": "",
}
data["sign"] = self.get_sign(json.dumps(data, separators=(",", ":")))
headers["Param"] = self.aes_cbc_encrypt_base64(json.dumps(data, separators=(",", ":")), key="eda0ac8d3b5e9446",
iv="0123456789ABCDEF")
response = self.session.post(url, headers=headers, data=data, proxies=self.proxies)
if "请输入正确的附加码" in response.text:
raise Exception("验证码识别失败")
tree = html.fromstring(response.text)
uCardSearchResult = tree.xpath("//div[@id='ctl00_MainContent_divUCardSearchResult']//tbody//tr")
# 获取查询结果
card_info = self.parse_table_rows(uCardSearchResult)
transaction = CardTransaction(**card_info)
if tree.xpath("//div[@id='ctl00_MainContent_panUCardTrade']"):
panUCardTrade = tree.xpath("//div[@id='ctl00_MainContent_panUCardTrade']//tbody//tr")
card_trade = self.parse_table_rows(panUCardTrade)
mapped_trade = map_alias_to_field(CardTransaction, card_trade)
else:
mapped_trade = {}
if mapped_trade:
new_transaction = transaction.copy(update=mapped_trade)
else:
new_transaction = transaction
new_transaction.cardNo = cardNo
new_transaction.cardSecret = cardSecret
return new_transaction
# 准备请求头
headers = self.mobile_headers.copy()
headers.update(
{
"Origin": self.BASE_URL,
"Sec-Fetch-Site": "same-origin",
"Referer": self.INDEX_URL,
}
)
@retry(tries=3)
def start(self, card_no: str, card_secret: str) -> CardTransaction:
viewstate, generator, AffineX, AffineY = self.get_index()
code = self.get_imgcode()
transaction = self.queryCard(card_no, card_secret, viewstate, generator, AffineX, AffineY, code)
# 计算签名并加密参数
data["sign"] = self._get_sign(json.dumps(data, separators=(",", ":")))
headers["Param"] = self._aes_encrypt(
json.dumps(data, separators=(",", ":")), key=self.AES_KEY, iv=self.AES_IV
)
try:
# 发送请求
response = self.session.post(
self.INDEX_URL,
headers=headers,
data=data,
proxies=self.proxies,
timeout=30,
)
response.raise_for_status()
alert = re.search(r"alert\('(.*?)'\)", response.text)
if alert:
if "请输入正确的附加码" in alert.group(1):
raise CaptchaError("验证码识别失败")
raise CardNotFoundError(alert.group(1))
soup = BeautifulSoup(response.text, "lxml")
result = CardTransaction(card_password=card_password, card_number=card_no)
query_list = []
query_result = soup.find("div", id="ctl00_MainContent_divUCardSearchResult")
if query_result:
query_result = query_result.find("table", id="table4")
if query_result:
query_list = self._parse_table_rows(query_result)
if query_list:
result.card_type = query_list[0].get("骏卡卡种")
result.card_status = query_list[0].get("骏卡状态")
result.j_points = query_list[0].get("J点面值")
result.locked_j_points = query_list[0].get("锁定J点")
result.available_j_points = query_list[0].get("可用J点")
query_result = soup.find("div", id="ctl00_MainContent_panUCardTrade")
if query_result:
query_result = query_result.find("table", id="table4")
used_list = []
if query_result:
used_list = self._parse_table_rows(query_result)
for used_data in used_list:
result.use_records.append(
CardUseRecord(
transaction_type=used_data.get("交易类型"),
usage_time=used_data.get("使用时间"),
used_j_points=used_data.get("使用 j 点"),
usage_note=used_data.get("使用说明"),
result=used_data.get("结果"),
)
)
return result
except (CaptchaError, CardNotFoundError):
# 重新抛出已处理的异常
raise
except requests.RequestException as e:
logger.error(f"查询卡片网络请求失败: {str(e)}")
raise NetworkError(f"网络请求异常: {str(e)}")
except Exception as e:
logger.error(f"查询卡片异常: {str(e)}")
raise CardQueryException(f"查询卡片异常: {str(e)}")
@retry(
tries=5,
delay=1,
backoff=2,
exceptions=(CaptchaError, NetworkError, CardQueryException),
)
def start(self, card_no: str, card_password: str) -> CardTransaction:
"""
启动骏卡查询流程
Args:
card_no: 骏卡卡号
card_password: 骏卡卡密
Returns:
CardTransaction: 卡片交易信息对象
Raises:
CardQueryException: 查询异常
"""
logger.info(f"开始查询骏卡: {card_no},卡密:{card_password}")
# 获取首页参数
viewstate, generator, affine_x, affine_y = self.get_index()
# 获取并识别验证码
code = self.get_captcha()
# 查询卡片信息
transaction = self.query_card(
card_no, card_password, viewstate, generator, affine_x, affine_y, code
)
logger.info(
f"完成骏卡查询: {card_no},卡密:{card_password},查询结果:{transaction}"
)
return transaction
def get_sign(self, data: str) -> str:
data = f"fK8kpKe7YcutStELHL1JuJw=={data}"
def _get_sign(self, data: str) -> str:
"""
计算签名
Args:
data: 要签名的数据
Returns:
str: MD5签名结果
"""
data = f"{self.SIGN_PREFIX}{data}"
m = hashlib.md5()
m.update(data.encode("utf-8"))
return m.hexdigest()
def aes_cbc_encrypt_base64(self, plaintext: str, key: str, iv: str) -> str:
def _aes_encrypt(self, plaintext: str, key: str, iv: str) -> str:
"""
加密
:param plaintext: 待加密字符串(明文)
:param key: AES 密钥16/24/32 字节)
:param iv: 初始化向量16 字节)
:return: 加密后 Base64 编码的字符串
AES-CBC加密并Base64编码
Args:
plaintext: 明文
key: AES密钥
iv: 初始化向量
Returns:
str: Base64编码的加密结果
"""
cipher = AES.new(key.encode("utf-8"), AES.MODE_CBC, iv.encode("utf-8"))
padded = pad(plaintext.encode('utf-8'), AES.block_size)
padded = pad(plaintext.encode("utf-8"), AES.block_size)
encrypted = cipher.encrypt(padded)
return base64.b64encode(encrypted).decode('utf-8')
return base64.b64encode(encrypted).decode("utf-8")
def parse_table_rows(self, tr_list: list[HtmlElement]) -> Dict[str, str]:
if not tr_list or len(tr_list) < 2:
return {}
header_cells = tr_list[0].xpath('.//th') or tr_list[0].xpath('.//td')
headers = [cell.text_content().strip() for cell in header_cells]
# 解析数据行
for tr in tr_list[1:]:
data_cells = tr.xpath('.//td')
if len(data_cells) != len(headers):
continue # 可选:跳过字段数不一致的行
row = {headers[i]: data_cells[i].text_content().strip() for i in range(len(headers))}
return row
return {}
def _parse_table_rows(
self, node_tree: BeautifulSoup | NavigableString
) -> list[dict[str, str]]:
"""
解析表格行
Args:
node_tree: 表格行HTML元素
Returns:
Dict[str, str]: 解析结果字典
"""
result = []
if not node_tree:
return result
header_list = []
if node_tree.select("tr.table_head2 > td"):
header_list = [
ele.get_text(strip=True)
for ele in node_tree.select("tr.table_head2 > td")
]
data_list = node_tree.select("tbody > tr:not(.table_head2)")
if data_list:
for data in data_list:
if len(data.select("td")) != len(header_list):
continue
data_dict = {
header_list[i]: data.select("td")[i].get_text(strip=True)
for i in range(len(header_list))
}
result.append(data_dict)
return result
def _extract_pattern(self, text: str, pattern: str) -> str:
"""
从文本中提取匹配模式的内容
Args:
text: 源文本
pattern: 正则表达式模式
Returns:
str: 匹配的内容
Raises:
ValueError: 未找到匹配内容
"""
matches = re.findall(pattern, text)
if not matches:
raise ValueError(f"未找到匹配的内容: {pattern}")
return matches[0]
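
The signing and Param-encryption steps used in query_card can be reproduced in isolation; a standalone sketch using the class constants above (the form dict is a placeholder, not the full field set):

import base64
import hashlib
import json

from Crypto.Cipher import AES
from Crypto.Util.Padding import pad

SIGN_PREFIX = "fK8kpKe7YcutStELHL1JuJw=="
AES_KEY = "eda0ac8d3b5e9446"
AES_IV = "0123456789ABCDEF"


def get_sign(data: str) -> str:
    # MD5 over the fixed prefix plus the compact-JSON form data
    return hashlib.md5(f"{SIGN_PREFIX}{data}".encode("utf-8")).hexdigest()


def aes_encrypt(plaintext: str) -> str:
    # AES-CBC with PKCS7 padding, Base64-encoded (sent as the Param header)
    cipher = AES.new(AES_KEY.encode("utf-8"), AES.MODE_CBC, AES_IV.encode("utf-8"))
    return base64.b64encode(cipher.encrypt(pad(plaintext.encode("utf-8"), AES.block_size))).decode("utf-8")


form = {"ctl00$MainContent$txtCardNo": "1234567890123456"}  # placeholder form data
payload = json.dumps(form, separators=(",", ":"))
print(get_sign(payload))
print(aes_encrypt(payload))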

app/spiders/manager.py

@@ -1,6 +1,7 @@
"""
爬虫管理器模块,负责爬虫的注册、创建、启动和管理。
"""
import asyncio
import logging
import inspect
@@ -18,8 +19,13 @@ logger = logging.getLogger("kami_spider")
class SpiderTask:
"""爬虫任务"""
def __init__(self, spider_name: str, task_id: Optional[str] = None,
params: Optional[Dict[str, Any]] = None):
def __init__(
self,
spider_name: str,
task_id: Optional[str] = None,
params: Optional[Dict[str, Any]] = None,
):
self.task_id = task_id or str(uuid.uuid4())
self.spider_name = spider_name
self.params = params or {}
@@ -29,7 +35,7 @@ class SpiderTask:
self.status = "pending" # pending, running, completed, failed
self.spider: Optional[Spider] = None
self.error: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
"""转换为字典"""
return {
@@ -43,133 +49,140 @@ class SpiderTask:
"duration": self.end_time - self.start_time if self.end_time else None,
"spider_status": self.spider.status if self.spider else None,
"stats": self.spider.get_stats() if self.spider else None,
"error": self.error
"error": self.error,
}
class SpiderManager:
"""爬虫管理器"""
_instance = None
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super(SpiderManager, cls).__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self):
if not getattr(self, "_initialized", False):
# 存储爬虫类
self.spider_classes: Dict[str, Type[Spider]] = {}
# 存储运行中的爬虫实例
self.running_spiders: Dict[str, Spider] = {}
# 存储爬虫任务
self.tasks: Dict[str, SpiderTask] = {}
# 爬虫包路径
self.spider_packages = ["app.spiders"]
# 事件循环
self.loop = asyncio.get_event_loop()
self._initialized = True
logger.info("爬虫管理器已初始化")
def register_spider(self, spider_class: Type[Spider]) -> None:
"""注册爬虫类"""
if not inspect.isclass(spider_class) or not issubclass(spider_class, Spider):
logger.warning(f"无法注册爬虫: {spider_class}它不是Spider的子类")
return
name = spider_class.__name__
if name in self.spider_classes:
logger.warning(f"爬虫 {name} 已存在,将被覆盖")
self.spider_classes[name] = spider_class
logger.info(f"已注册爬虫: {name}")
def register_spiders_from_module(self, module_name: str) -> None:
"""从模块注册爬虫"""
try:
module = importlib.import_module(module_name)
for name, obj in inspect.getmembers(module):
if inspect.isclass(obj) and issubclass(obj, Spider) and obj.__module__ == module_name:
if (
inspect.isclass(obj)
and issubclass(obj, Spider)
and obj.__module__ == module_name
):
self.register_spider(obj)
except ImportError as e:
logger.error(f"导入模块 {module_name} 失败: {e}")
def discover_spiders(self) -> None:
"""发现并注册所有爬虫"""
for package_name in self.spider_packages:
try:
package = importlib.import_module(package_name)
for _, name, is_pkg in pkgutil.iter_modules(package.__path__):
if not is_pkg and not name.startswith('_'):
if not is_pkg and not name.startswith("_"):
module_name = f"{package_name}.{name}"
self.register_spiders_from_module(module_name)
elif is_pkg and not name.startswith('_') and name != 'base':
elif is_pkg and not name.startswith("_") and name != "base":
# 递归检查子包
subpackage_name = f"{package_name}.{name}"
subpackage = importlib.import_module(subpackage_name)
for _, subname, is_subpkg in pkgutil.iter_modules(subpackage.__path__):
if not is_subpkg and not subname.startswith('_'):
for _, subname, is_subpkg in pkgutil.iter_modules(
subpackage.__path__
):
if not is_subpkg and not subname.startswith("_"):
module_name = f"{subpackage_name}.{subname}"
self.register_spiders_from_module(module_name)
except (ImportError, AttributeError) as e:
logger.error(f"发现爬虫包 {package_name} 时出错: {e}")
logger.info(f"已发现 {len(self.spider_classes)} 个爬虫")
def get_spider_class(self, name: str) -> Optional[Type[Spider]]:
"""获取爬虫类"""
return self.spider_classes.get(name)
def get_all_spiders(self) -> Dict[str, Type[Spider]]:
"""获取所有爬虫类"""
return self.spider_classes
async def create_spider(self, name: str, **kwargs) -> Optional[Spider]:
"""创建爬虫实例"""
spider_class = self.get_spider_class(name)
if not spider_class:
logger.error(f"找不到爬虫: {name}")
return None
try:
spider = spider_class(**kwargs)
return spider
except Exception as e:
logger.error(f"创建爬虫 {name} 失败: {e}", exc_info=True)
return None
async def start_spider(self, name: str, **kwargs) -> Optional[str]:
"""启动爬虫返回任务ID"""
task = SpiderTask(name, params=kwargs)
self.tasks[task.task_id] = task
# 创建爬虫实例
spider = await self.create_spider(name, **kwargs)
if not spider:
task.status = "failed"
task.error = f"创建爬虫 {name} 失败"
return None
task.spider = spider
self.running_spiders[task.task_id] = spider
# 使用事件循环运行爬虫
asyncio.create_task(self._run_spider_task(task))
return task.task_id
async def _run_spider_task(self, task: SpiderTask) -> None:
"""运行爬虫任务"""
task.status = "running"
task.start_time = time.time()
try:
await task.spider.start()
task.status = "completed"
@@ -182,53 +195,53 @@ class SpiderManager:
# 从运行中移除
if task.task_id in self.running_spiders:
del self.running_spiders[task.task_id]
def get_task(self, task_id: str) -> Optional[SpiderTask]:
"""获取任务"""
return self.tasks.get(task_id)
def get_all_tasks(self) -> Dict[str, SpiderTask]:
"""获取所有任务"""
return self.tasks
def pause_task(self, task_id: str) -> bool:
"""暂停任务"""
task = self.get_task(task_id)
if not task or not task.spider:
logger.warning(f"任务 {task_id} 不存在或无爬虫实例")
return False
if task.spider.status == SpiderStatus.RUNNING:
task.spider.pause()
return True
return False
def resume_task(self, task_id: str) -> bool:
"""恢复任务"""
task = self.get_task(task_id)
if not task or not task.spider:
logger.warning(f"任务 {task_id} 不存在或无爬虫实例")
return False
if task.spider.status == SpiderStatus.PAUSED:
task.spider.resume()
return True
return False
def stop_task(self, task_id: str) -> bool:
"""停止任务"""
task = self.get_task(task_id)
if not task or not task.spider:
logger.warning(f"任务 {task_id} 不存在或无爬虫实例")
return False
if task.spider.status in [SpiderStatus.RUNNING, SpiderStatus.PAUSED]:
task.spider.stop()
if task_id in self.running_spiders:
del self.running_spiders[task_id]
return True
return False
def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
"""获取任务状态"""
task = self.get_task(task_id)
@@ -238,4 +251,4 @@ class SpiderManager:
# 单例
spider_manager = SpiderManager()
spider_manager = SpiderManager()
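
The singleton can also drive spiders outside the API layer; a hedged sketch (the spider name must match a discovered class, and the sleep is only to let the background task make progress):

import asyncio

from app.spiders.manager import spider_manager


async def main():
    spider_manager.discover_spiders()
    task_id = await spider_manager.start_spider("SimpleSpider", max_depth=1)
    await asyncio.sleep(5)  # give the background task time to run
    print(spider_manager.get_task_status(task_id))


asyncio.run(main())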

View File

@@ -1 +0,0 @@

View File

@@ -1 +1 @@
# 通用工具函数,可在此扩展
# 通用工具函数,可在此扩展

requirements.txt

@@ -2,7 +2,7 @@ fastapi==0.115.12
uvicorn[standard]==0.34.2
gunicorn==23.0.0
pydantic==2.8.2
pydantic-settings==2.2.1
pydantic-settings==2.5.2
opentelemetry-api==1.33.1
opentelemetry-sdk==1.33.1
opentelemetry-instrumentation-fastapi==0.54b1
@@ -12,15 +12,14 @@ colorlog==6.8.2
python-dotenv==1.0.1
redis==5.0.3
aiomysql==0.2.0
aiohttp==3.9.5
bs4==0.0.2
aiohttp==3.10.11
beautifulsoup4==4.12.3
lxml==5.2.0
aiofiles==23.2.1
requests==2.32.3
urllib3==2.2.2
yarl==1.11.1
ddddocr==1.0.6
ddddocr~=1.5.6
tenacity==9.1.2
retry==0.9.2
pycryptodome==3.20.0