diff --git a/.tool-versions b/.tool-versions new file mode 100644 index 0000000..9be3b66 --- /dev/null +++ b/.tool-versions @@ -0,0 +1 @@ +python 3.13.9 diff --git a/src/cmd/scripts.py b/src/cmd/scripts.py index 01e69eb..17bc857 100644 --- a/src/cmd/scripts.py +++ b/src/cmd/scripts.py @@ -23,9 +23,9 @@ from src.service.itunes import ItunesService def run_redeem_task( - master_order: RechargeQueryModel, - master_node_service: MasterNodeService, - reties=3, + master_order: RechargeQueryModel, + master_node_service: MasterNodeService, + reties=3, ): redis_conn = redis_pool.get_connection() redis_client = redis.Redis(connection_pool=redis_conn) @@ -58,8 +58,8 @@ def run_redeem_task( logger.info(f"查找已存在Cookie信息:{apple_account_schema.account}") # 如果更新过账户密码 if ( - apple_account_schema.account == master_order.account - and apple_account_schema.password != master_order.password + apple_account_schema.account == master_order.account + and apple_account_schema.password != master_order.password ): # 删除登录标识 redis_client.delete(f"apple_account_{master_order.account}") @@ -74,11 +74,13 @@ def run_redeem_task( redis_client.setex( f"apple_account_{master_order.account}", time=timedelta(seconds=30), - value=pickle.dumps(AppleAccountSchema( - account=master_order.account, - password=master_order.password, - status=2, - ).model_dump()), + value=pickle.dumps( + AppleAccountSchema( + account=master_order.account, + password=master_order.password, + status=2, + ).model_dump() + ), ) response_schema = itunes_service.login( AppleAccountModel( @@ -92,12 +94,14 @@ def run_redeem_task( redis_client.setex( f"apple_account_{master_order.account}", time=timedelta(days=1), - value=pickle.dumps(AppleAccountSchema( - account=master_order.account, - password=master_order.password, - status=0, - login_schema=response_schema, - ).model_dump()), + value=pickle.dumps( + AppleAccountSchema( + account=master_order.account, + password=master_order.password, + status=0, + login_schema=response_schema, + ).model_dump() + ), ) master_node_service.update_order_status( ItunesRedeemRequestModel( @@ -117,12 +121,14 @@ def run_redeem_task( redis_client.setex( f"apple_account_{master_order.account}", time=timedelta(minutes=9), - value=pickle.dumps(AppleAccountSchema( - account=master_order.account, - password=master_order.password, - status=1, - login_schema=response_schema, - ).model_dump()) + value=pickle.dumps( + AppleAccountSchema( + account=master_order.account, + password=master_order.password, + status=1, + login_schema=response_schema, + ).model_dump() + ), ) redeem_result = itunes_service.redeem( master_order.cardPass, @@ -141,9 +147,9 @@ def run_redeem_task( redis_client.delete(f"apple_account_{master_order.account}") return run_redeem_task(master_order, master_node_service, reties - 1) # if redeem_result.status == 40: - # logger.warning("充值1分钟限制,1分钟后重试") - # time.sleep(60) - # return run_redeem_task(master_order, master_node_service, reties - 1) + # logger.warning("充值1分钟限制,1分钟后重试") + # time.sleep(60) + # return run_redeem_task(master_order, master_node_service, reties - 1) # 更新兑换状态 master_node_service.update_order_status( ItunesRedeemRequestModel( @@ -212,12 +218,14 @@ def run(): # 10分钟打印一次信息 if datetime.now().second % 30 == 0 and not has_been_console: has_been_console = True - print(f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\t心跳正常 当前状态:{process.is_alive()}", - flush=True, ) + print( + f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\t心跳正常 当前状态:{process.is_alive()}", + flush=True, + ) if datetime.now().second % 30 != 0 and has_been_console: has_been_console = False if not process.is_alive(): process.close() process = Process(target=run_task, args=(), daemon=True) process.start() - process.join() \ No newline at end of file + process.join() diff --git a/src/database/mysql_db.py b/src/database/mysql_db.py index 4682e15..da1da0f 100644 --- a/src/database/mysql_db.py +++ b/src/database/mysql_db.py @@ -8,5 +8,5 @@ engine = create_engine( echo_pool=True, echo=True, pool_recycle=300, # 每隔 300 秒(5分钟)强制回收连接 - pool_pre_ping=True # 每次从连接池取连接前执行简单 ping 测试 + pool_pre_ping=True, # 每次从连接池取连接前执行简单 ping 测试 ) diff --git a/src/integrations/itunes/api.py b/src/integrations/itunes/api.py index 88902ed..8c40e19 100644 --- a/src/integrations/itunes/api.py +++ b/src/integrations/itunes/api.py @@ -10,7 +10,8 @@ from requests.adapters import HTTPAdapter from src.integrations.itunes.models.login import ( ItunesLoginResponse, ItunesFailLoginPlistData, - ItunesSuccessLoginPlistData, ItunesAccountInfo, + ItunesSuccessLoginPlistData, + ItunesAccountInfo, ) from src.integrations.itunes.models.redeem import ( RedeemFailResponseModel, @@ -24,11 +25,12 @@ from src.service.proxy import ProxyService # 禁用 SSL 警告 urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + class AppleClient: def __init__(self): self.__session = requests.Session() self.__session.verify = False # 禁用 SSL 证书验证 - + # 设置重试策略 retry_strategy = Retry( total=3, # 总重试次数 @@ -41,14 +43,16 @@ class AppleClient: adapter = HTTPAdapter(max_retries=retry_strategy) self.__session.mount("http://", adapter) self.__session.mount("https://", adapter) - + # 设置默认请求头 - self.__session.headers.update({ - 'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36', - 'Accept': '*/*', - 'Accept-Encoding': 'gzip, deflate', - 'Connection': 'keep-alive' - }) + self.__session.headers.update( + { + "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36", + "Accept": "*/*", + "Accept-Encoding": "gzip, deflate", + "Connection": "keep-alive", + } + ) def query_sign_sap_setup(self, signature: LoginSignatureModel, retries=3) -> str: if retries <= 0: @@ -87,7 +91,11 @@ class AppleClient: return response.text def login( - self, sign_map: AuthenticateModel, account_info: ItunesAccountInfo, server_id: str = "", retries: int = 5 + self, + sign_map: AuthenticateModel, + account_info: ItunesAccountInfo, + server_id: str = "", + retries: int = 5, ) -> ItunesLoginResponse: if retries <= 0: logger.error("登录重试次数已用完") @@ -96,11 +104,11 @@ class AppleClient: response=ItunesFailLoginPlistData( status=30, failureType="MAX_RETRIES_EXCEEDED", - customerMessage="登录重试次数已用完" + customerMessage="登录重试次数已用完", ), - originLog="登录重试次数已用完" + originLog="登录重试次数已用完", ) - + headers = { "X-Apple-ActionSignature": sign_map.signature, "X-Apple-Store-Front": "143465-19,17", @@ -123,16 +131,18 @@ class AppleClient: "itsMetricsR": "Genre-CN-Mobile Software Applications-29099@@Mobile Software Applications-main@@@@", "s_vnum_n2_us": "0|1", } - + params = {} if server_id != "": url = f"https://p{server_id}-buy.itunes.apple.com/WebObjects/MZFinance.woa/wa/authenticate" params = {"Pod": server_id, "PRH": server_id} else: - url = "https://buy.itunes.apple.com/WebObjects/MZFinance.woa/wa/authenticate" - + url = ( + "https://buy.itunes.apple.com/WebObjects/MZFinance.woa/wa/authenticate" + ) + self.__session.headers["X-Apple-Store-Front"] = "143465-19,12" - + try: response = self.__session.post( url, @@ -152,7 +162,7 @@ class AppleClient: except requests.exceptions.RequestException as e: logger.error(f"请求错误: {str(e)}") return self.login(sign_map, account_info, server_id, retries - 1) - + if response.is_redirect: redirect_url = response.headers["Location"] groups = re.search(r"https://p(\d+)-buy.itunes.apple.com", redirect_url) @@ -166,11 +176,11 @@ class AppleClient: response=ItunesFailLoginPlistData( status=30, failureType="INVALID_REDIRECT", - customerMessage="无法解析重定向URL" + customerMessage="无法解析重定向URL", ), - originLog=response.text + originLog=response.text, ) - + try: response_dict_data = parse_xml(response.text) except Exception as e: @@ -178,39 +188,43 @@ class AppleClient: return ItunesLoginResponse( serverId=server_id, response=ItunesFailLoginPlistData( - status=30, - failureType="PARSE_ERROR", - customerMessage="解析响应失败" + status=30, failureType="PARSE_ERROR", customerMessage="解析响应失败" ), - originLog=response.text + originLog=response.text, ) - + if "failureType" in response_dict_data: status = 31 - if response_dict_data.get("metrics", {}).get("dialogId") == "MZFinance.AccountDisabled": + if ( + response_dict_data.get("metrics", {}).get("dialogId") + == "MZFinance.AccountDisabled" + ): status = 14 - if response_dict_data.get("metrics", {}).get("dialogId") == "MZFinance.DisabledAndFraudLocked": + if ( + response_dict_data.get("metrics", {}).get("dialogId") + == "MZFinance.DisabledAndFraudLocked" + ): status = 14 if response_dict_data.get("failureType") == "-5000": status = 13 if status == 31: logger.warning("登录状态未知:", response_dict_data) - response_model = ItunesFailLoginPlistData(**{"status": status, **response_dict_data}) + response_model = ItunesFailLoginPlistData( + **{"status": status, **response_dict_data} + ) else: response_model = ItunesSuccessLoginPlistData(**response_dict_data) - + return ItunesLoginResponse( - serverId=server_id, - response=response_model, - originLog=response.text + serverId=server_id, response=response_model, originLog=response.text ) def redeem( - self, - code: str, - itunes: ItunesLoginModel, - account_info: ItunesAccountInfo, - reties=5, + self, + code: str, + itunes: ItunesLoginModel, + account_info: ItunesAccountInfo, + reties=5, ) -> RedeemSuccessResponse | RedeemFailResponseModel: if reties <= 0: logger.error("兑换失败,兑换重试次数已用完") @@ -237,7 +251,7 @@ class AppleClient: "Content-Type": "application/x-apple-plist; Charset=UTF-8", "User-Agent": "MacAppStore/2.0 (Macintosh; OS X 12.10) AppleWebKit/600.1.3.41", "Referer": f"https://p{itunes.server_id}-buy.itunes.apple.com/WebObjects/MZFinance.woa/wa/com.apple" - f".jingle.app.finance.DirectAction/redeemCode?cl=iTunes&pg=Music", + f".jingle.app.finance.DirectAction/redeemCode?cl=iTunes&pg=Music", } try: response = self.__session.post( @@ -299,8 +313,8 @@ class AppleClient: result = RedeemFailResponseModel.model_validate(response.json()) result.origin_log = response.text if ( - result.errorMessageKey - == "MZCommerce.GiftCertificateAlreadyRedeemed" + result.errorMessageKey + == "MZCommerce.GiftCertificateAlreadyRedeemed" ): # 已经被兑换 result.status = 12 @@ -308,31 +322,28 @@ class AppleClient: # 没有这个卡密 result.status = 11 elif ( - result.errorMessageKey - == "MZCommerce.NatIdYearlyCapExceededException" + result.errorMessageKey + == "MZCommerce.NatIdYearlyCapExceededException" ): # 年限额 result.status = 31 elif ( - result.errorMessageKey - == "MZCommerce.NatIdDailyCapExceededException" + result.errorMessageKey + == "MZCommerce.NatIdDailyCapExceededException" ): # 日限额 result.status = 31 # 国籍问题 elif ( - result.errorMessageKey - == "MZCommerce.GiftCertRedeemStoreFrontMismatch" + result.errorMessageKey + == "MZCommerce.GiftCertRedeemStoreFrontMismatch" ): result.status = 15 - elif ( - result.errorMessageKey - == "MZCommerce.GiftCertificateDisabled" - ): + elif result.errorMessageKey == "MZCommerce.GiftCertificateDisabled": result.status = 11 else: logger.error(f"失败状态未知:{response.json()}") - + if result.status == -1 or result.status == 0: result.status = 30 logger.warning("兑换状态未知:", response.text) diff --git a/src/integrations/itunes/models/redeem.py b/src/integrations/itunes/models/redeem.py index 4496fda..6f20ec1 100644 --- a/src/integrations/itunes/models/redeem.py +++ b/src/integrations/itunes/models/redeem.py @@ -10,7 +10,9 @@ class RedeemFailResponseModel(BaseModel): default="", alias="userPresentableErrorMessage" ) origin_log: str = Field(default="") - origin_status_code: int = Field(default=0, alias="originStatusCode", description="原始状态码") + origin_status_code: int = Field( + default=0, alias="originStatusCode", description="原始状态码" + ) status: int = Field(..., alias="status", description="0.需要登录 1.正常") diff --git a/src/integrations/june/api.py b/src/integrations/june/api.py index fb8ef0d..c0519d0 100644 --- a/src/integrations/june/api.py +++ b/src/integrations/june/api.py @@ -28,15 +28,15 @@ class SixClient: # self.session = requests.Session() def _do_post( - self, post_data: Any, type_: str, start_now_fun: str = "0", reties: int = 3 + self, post_data: Any, type_: str, start_now_fun: str = "0", reties: int = 3 ) -> AppleSixResponseModel | None: if reties <= 0: return req_count = random.randint(0, 90) + 1 text = ( - str(int(time.time()) + req_count) - + str(Config.user_info.uid).zfill(4) - + str(req_count) + str(int(time.time()) + req_count) + + str(Config.user_info.uid).zfill(4) + + str(req_count) ).zfill(16) if len(text) > 16: text = text[:16] @@ -86,13 +86,13 @@ class SixClient: return self._do_post(post_data, type_, start_now_fun, reties - 1) if response.ok: if ( - response.headers["sign"] - and hashlib.md5( - ( + response.headers["sign"] + and hashlib.md5( + ( "abc_" + response.text + text + "by六月的风_联系qq:1023092054" - ).encode("utf-8") - ).hexdigest() - != response.headers["sign"] + ).encode("utf-8") + ).hexdigest() + != response.headers["sign"] ): raise Exception("签名错误") return AppleSixResponseModel(**response.json()) @@ -117,7 +117,7 @@ class SixClient: return False def login_remote_apple_account( - self, account: AppleAccountModel + self, account: AppleAccountModel ) -> AppleSixResponseModel[dict]: response = self._do_post( { @@ -134,7 +134,7 @@ class SixClient: return AppleSixResponseModel[dict].model_validate(response.model_dump()) def get_sign_sap_setup( - self, reties: int = 3 + self, reties: int = 3 ) -> AppleSixResponseModel[LoginSignatureModel] | None: if reties < 0: return @@ -165,11 +165,11 @@ class SixClient: return response def get_sign_sap_setup_cert( - self, - account: AppleAccountModel, - sign: AppleSixResponseModel[LoginSignatureModel], - sign_sap_setup: str, - reties: int = 3, + self, + account: AppleAccountModel, + sign: AppleSixResponseModel[LoginSignatureModel], + sign_sap_setup: str, + reties: int = 3, ) -> AppleSixResponseModel[AuthenticateModel] | None: if reties < 0: return @@ -197,8 +197,12 @@ class SixClient: try: response.Data = json.loads(decode_and_decompress(response.Data)) except AttributeError as e: - logger.error(f"获取cert失败,{response},错误信息:{traceback.format_exc()}") - return self.get_sign_sap_setup_cert(account, sign, sign_sap_setup, reties - 1) + logger.error( + f"获取cert失败,{response},错误信息:{traceback.format_exc()}" + ) + return self.get_sign_sap_setup_cert( + account, sign, sign_sap_setup, reties - 1 + ) if response.Data.get("msg") == "请重试": logger.info(f"重试六月登录,{response}") time.sleep(1) diff --git a/src/integrations/june/test_api.py b/src/integrations/june/test_api.py index 861a740..4ba04bc 100644 --- a/src/integrations/june/test_api.py +++ b/src/integrations/june/test_api.py @@ -6,4 +6,4 @@ from src.integrations.june.api import SixClient class TestSixClient(TestCase): def test_get_sign_sap_setup(self): result = SixClient().get_sign_sap_setup() - print(result) \ No newline at end of file + print(result) diff --git a/src/integrations/master_node/api.py b/src/integrations/master_node/api.py index 7fac5f6..4689fb2 100644 --- a/src/integrations/master_node/api.py +++ b/src/integrations/master_node/api.py @@ -12,10 +12,17 @@ from src.integrations.master_node.models import ( CommonResponseSchema, ItunesRedeemRequestModel, ) +from src.utils.crypto import ( + AESKey, + encrypt_with_aes, + encrypt_with_rsa, + decrypt_with_aes, +) # 当前节点id machineId = uuid.uuid4().hex + class MasterNodeService: single_lock = RLock() @@ -41,7 +48,13 @@ class MasterNodeService: self.__address, "/api/cardInfo/appleCard/rechargeOrder/handler" ), data={ - "machineId": machineId, + "machineId": encrypt_with_aes( + f"{machineId}:{int(time.time())}", + AESKey.load_from_base64( + "P0x6Gy6dXIpPbhE7PHxaHbfZHhsbT2qNPlx3qbHTP1o=" + ), + AESKey.load_from_base64("nywao1XkDXeYwbPeWh+SxA=="), + ), }, proxies={ "http": "", @@ -64,8 +77,16 @@ class MasterNodeService: if not result.data or not result.data.account: return RechargeQueryModel() - self.__account = result.data.account - self.__password = result.data.password + self.__account = decrypt_with_aes( + result.data.account, + AESKey.load_from_base64("P0x6Gy6dXIpPbhE7PHxaHbfZHhsbT2qNPlx3qbHTP1o="), + AESKey.load_from_base64("nywao1XkDXeYwbPeWh+SxA=="), + ) + self.__password = decrypt_with_aes( + result.data.password, + AESKey.load_from_base64("P0x6Gy6dXIpPbhE7PHxaHbfZHhsbT2qNPlx3qbHTP1o="), + AESKey.load_from_base64("nywao1XkDXeYwbPeWh+SxA=="), + ) return result.data # 修改充值状态 diff --git a/src/schema/models.py b/src/schema/models.py index a7fb453..ff49657 100644 --- a/src/schema/models.py +++ b/src/schema/models.py @@ -48,5 +48,5 @@ Base.metadata.create_all(engine) def get_session() -> sessionmaker[Session]: # 创建会话工厂 - session= sessionmaker(bind=engine) + session = sessionmaker(bind=engine) return session diff --git a/src/service/itunes.py b/src/service/itunes.py index 5571908..25b2aca 100644 --- a/src/service/itunes.py +++ b/src/service/itunes.py @@ -5,7 +5,8 @@ from loguru import logger from src.integrations.itunes.api import AppleClient from src.integrations.itunes.models.login import ( - ItunesSuccessLoginPlistData, ItunesAccountInfo, + ItunesSuccessLoginPlistData, + ItunesAccountInfo, ) from src.integrations.itunes.models.redeem import ( RedeemSuccessResponse, @@ -31,9 +32,9 @@ class ItunesService: self.apple_client_service = AppleClient() def login( - self, - account: AppleAccountModel, - retries: int = 3, + self, + account: AppleAccountModel, + retries: int = 3, ) -> LoginSuccessResponse | LoginFailureResponse: """ 登录itunes @@ -56,9 +57,12 @@ class ItunesService: ) middle_time_3 = time.time() logger.info(f"[+] 获取签到证书耗时(六月): {middle_time_3 - middle_time_2}") - login_schema = self.apple_client_service.login(sign_sap_cert.Data, ItunesAccountInfo( - account_name=account.account, - )) + login_schema = self.apple_client_service.login( + sign_sap_cert.Data, + ItunesAccountInfo( + account_name=account.account, + ), + ) logger.info(f"[+] 登录耗时(苹果): {time.time() - middle_time_3}") logger.info(f"[+] 登录耗时合计: {time.time() - start_time}") session = get_session()() @@ -108,7 +112,7 @@ class ItunesService: return response_result def redeem( - self, code: str, item: RedeemRequestModel, set_cookie: bool = False + self, code: str, item: RedeemRequestModel, set_cookie: bool = False ) -> AppleAccountRedeemResponse: """ 兑换代码 @@ -132,7 +136,7 @@ class ItunesService: ), ItunesAccountInfo( account_name=item.account_name, - ) + ), ) if isinstance(result, RedeemSuccessResponse): # 充值成功 diff --git a/src/service/proxy.py b/src/service/proxy.py index 0e988b4..09a4498 100644 --- a/src/service/proxy.py +++ b/src/service/proxy.py @@ -12,6 +12,7 @@ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) # 账户列表 account_list = {} + class ProxyService: def __init__(self) -> None: self.proxy_list = setting.proxies.address @@ -19,7 +20,7 @@ class ProxyService: self.test_urls = [ "https://www.baidu.com", "https://www.apple.com", - "https://buy.itunes.apple.com" + "https://buy.itunes.apple.com", ] self.failed_proxies = {} # 记录失败的代理 @@ -45,12 +46,14 @@ class ProxyService: proxy = "" if self.proxy_list: # 随机打乱代理列表顺序 - available_proxies = [p for p in self.proxy_list if not self.is_proxy_failed(p)] + available_proxies = [ + p for p in self.proxy_list if not self.is_proxy_failed(p) + ] if not available_proxies: # 如果所有代理都失败,重置失败记录 self.failed_proxies.clear() available_proxies = self.proxy_list - + random.shuffle(available_proxies) for proxy in available_proxies: if self.test_proxy(proxy): @@ -70,7 +73,9 @@ class ProxyService: def is_proxy_failed(self, proxy: str) -> bool: """检查代理是否在失败列表中且未过期""" if proxy in self.failed_proxies: - if time.time() - self.failed_proxies[proxy] < 300: # 5分钟内不重用失败的代理 + if ( + time.time() - self.failed_proxies[proxy] < 300 + ): # 5分钟内不重用失败的代理 return True else: del self.failed_proxies[proxy] @@ -82,17 +87,17 @@ class ProxyService: def test_proxy(self, proxy: str) -> bool: """测试代理是否可用 - + Args: proxy: 代理地址 - + Returns: bool: 代理是否可用 """ proxies = self.warp_proxy(proxy) if not proxies: return False - + for url in self.test_urls: try: response = requests.get( @@ -101,21 +106,23 @@ class ProxyService: timeout=self.proxy_timeout, verify=False, # 禁用 SSL 证书验证 headers={ - 'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36', - 'Accept': '*/*', - 'Accept-Encoding': 'gzip, deflate', - 'Connection': 'keep-alive' - } + "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36", + "Accept": "*/*", + "Accept-Encoding": "gzip, deflate", + "Connection": "keep-alive", + }, ) if response.status_code == 200: return True - except (requests.exceptions.SSLError, - requests.exceptions.ProxyError, - requests.exceptions.ConnectionError, - requests.exceptions.Timeout) as e: + except ( + requests.exceptions.SSLError, + requests.exceptions.ProxyError, + requests.exceptions.ConnectionError, + requests.exceptions.Timeout, + ) as e: logger.debug(f"代理测试失败: {proxy}, URL: {url}, 错误: {str(e)}") continue - + # 如果所有URL都测试失败,标记代理为失败 self.mark_proxy_failed(proxy) - return False \ No newline at end of file + return False diff --git a/src/service/test_itunes.py b/src/service/test_itunes.py index d7221a3..e6840d5 100644 --- a/src/service/test_itunes.py +++ b/src/service/test_itunes.py @@ -31,7 +31,9 @@ class TestItunesService(TestCase): AppleAccountModel( account=k, password=v, - ), sign_sap_from_june, sign_sap_setup_buffer + ), + sign_sap_from_june, + sign_sap_setup_buffer, ) middle_time_3 = time.time() print("get_sign_sap_setup_cert", middle_time_3 - middle_time_2) diff --git a/src/utils/crypto.py b/src/utils/crypto.py index 6c70125..11d683f 100644 --- a/src/utils/crypto.py +++ b/src/utils/crypto.py @@ -1,14 +1,16 @@ """ -非对称加密辅助函数模块 +加密算法辅助函数模块 -提供 RSA 密钥对生成、加密、解密、签名、验证等功能 +提供 RSA(非对称加密)和 AES(对称加密)密钥对生成、加密、解密、签名、验证等功能 """ import base64 -from typing import Tuple, Optional +import os +from typing import Tuple, Optional, Literal from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import rsa, padding +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.backends import default_backend @@ -111,7 +113,7 @@ class RSAKeyPair: def encrypt_with_rsa( - plaintext: str, public_key_pem: str, encoding: str = "base64" + plaintext: str, public_key_pem: str, encoding: Literal["base64", "hex"] = "base64" ) -> str: """ 使用 RSA 公钥加密数据(OAEP 填充) @@ -153,7 +155,7 @@ def encrypt_with_rsa( def decrypt_with_rsa( - ciphertext: str, private_key_pem: str, encoding: str = "base64" + ciphertext: str, private_key_pem: str, encoding: Literal["base64", "hex"] = "base64" ) -> str: """ 使用 RSA 私钥解密数据 @@ -195,7 +197,9 @@ def decrypt_with_rsa( return plaintext_bytes.decode("utf-8") -def sign_with_rsa(message: str, private_key_pem: str, encoding: str = "base64") -> str: +def sign_with_rsa( + message: str, private_key_pem: str, encoding: Literal["base64", "hex"] = "base64" +) -> str: """ 使用 RSA 私钥对消息进行数字签名 @@ -232,7 +236,10 @@ def sign_with_rsa(message: str, private_key_pem: str, encoding: str = "base64") def verify_rsa_signature( - message: str, signature: str, public_key_pem: str, encoding: str = "base64" + message: str, + signature: str, + public_key_pem: str, + encoding: Literal["base64", "hex"] = "base64", ) -> bool: """ 使用 RSA 公钥验证数字签名 @@ -260,8 +267,6 @@ def verify_rsa_signature( signature_bytes = base64.b64decode(signature) elif encoding == "hex": signature_bytes = bytes.fromhex(signature) - else: - raise ValueError(f"不支持的编码方式: {encoding}") # 验证签名 public_key.verify( @@ -278,7 +283,7 @@ def verify_rsa_signature( def generate_key_pair_files( - private_key_path: str, public_key_path: str, key_size: int = 2048 + private_key_path: str, public_key_path: str, key_size: Literal[2048, 4096] = 2048 ) -> Tuple[str, str]: """ 生成 RSA 密钥对并保存到文件 @@ -335,3 +340,189 @@ def load_key_pair_from_files( public_key = key_pair.public_key return RSAKeyPair(private_key, public_key) + + +# ==================== AES 对称加密功能 ==================== + + +class AESKey: + """ + AES 密钥管理类 + + 支持 AES-128, AES-192, AES-256 + """ + + def __init__(self, key: bytes): + """ + 初始化 AES 密钥 + + Args: + key: 密钥字节 (16, 24 或 32 字节分别对应 AES-128, AES-192, AES-256) + + Raises: + ValueError: 密钥長度有效 + """ + if len(key) not in (16, 24, 32): + raise ValueError("密钥大小必须是 16, 24 或 32 字节") + self.key = key + + @staticmethod + def generate(key_size: int = 32) -> "AESKey": + """ + 生成随機 AES 密钥 + + Args: + key_size: 密钥大小 (16, 24 或 32), 默認 32 (即 AES-256) + + Returns: + AESKey: 真機密钥对象 + + Raises: + ValueError: key_size 有效 + """ + if key_size not in (16, 24, 32): + raise ValueError("密钥大小必须是 16, 24 或 32") + key = os.urandom(key_size) + return AESKey(key) + + def get_key_hex(self) -> str: + """ + 获取 Hex 格式的密钥字符串 + + Returns: + str: Hex 编码的密钥 + """ + return self.key.hex() + + def get_key_base64(self) -> str: + """ + 获取 Base64 格式的密钥字符串 + + Returns: + str: Base64 编码的密钥 + """ + return base64.b64encode(self.key).decode("utf-8") + + @staticmethod + def load_from_hex(hex_key: str) -> "AESKey": + """ + 从 Hex 字符串加载密钥 + + Args: + hex_key: Hex 编码的密钥字符串 + + Returns: + AESKey: AES 密钥对象 + """ + key = bytes.fromhex(hex_key) + return AESKey(key) + + @staticmethod + def load_from_base64(base64_key: str) -> "AESKey": + """ + 从 Base64 字符串加载密钥 + + Args: + base64_key: Base64 编码的密钥字符串 + + Returns: + AESKey: AES 密钥对象 + """ + key = base64.b64decode(base64_key) + return AESKey(key) + + +def encrypt_with_aes( + plaintext: str, + aes_key: AESKey, + iv: AESKey, + encoding: Literal["base64", "hex"] = "base64", +) -> str: + """ + 使用 AES 密钥加密数据 (使用 CBC 模式和 PKCS7 填充) + + Args: + plaintext: 明文字符串 + aes_key: AES 密钥对象 + iv: 初始化向量对象 + encoding: 输出编码方式 ('base64' 或 'hex'), 默认 'base64' + + Returns: + str: 加密后的数据 (Base64 或 Hex 编码) + + Raises: + ValueError: 编码方式不支持 + """ + # 转换明文为字节 + plaintext_bytes = plaintext.encode("utf-8") + + # 应用 PKCS7 填充 + block_size = 16 + padding_length = block_size - (len(plaintext_bytes) % block_size) + plaintext_padded = plaintext_bytes + bytes([padding_length] * padding_length) + + # 创建 AES 加密器 + cipher = Cipher( + algorithms.AES(aes_key.key), + modes.CBC(iv.key), + backend=default_backend(), + ) + encryptor = cipher.encryptor() + ciphertext = encryptor.update(plaintext_padded) + encryptor.finalize() + + # 返回密文 (IV 由调用者保管) + encrypted_data = ciphertext + + if encoding == "base64": + return base64.b64encode(encrypted_data).decode("utf-8") + elif encoding == "hex": + return encrypted_data.hex() + return None + + +def decrypt_with_aes( + ciphertext: str, + aes_key: AESKey, + iv: AESKey, + encoding: Literal["base64", "hex"] = "base64", +) -> str: + """ + 使用 AES 密钥解密数据 + + Args: + ciphertext: 加密数据 (Base64 或 Hex 编码) + aes_key: AES 密钥对象 + iv: 初始化向量对象 + encoding: 输入编码方式 ('base64' 或 'hex'), 默认 'base64' + + Returns: + str: 解密后的明文字符串 + + Raises: + ValueError: 编码方式不支持或解密失败 + """ + # 解码密文 + if encoding == "base64": + encrypted_data = base64.b64decode(ciphertext) + elif encoding == "hex": + encrypted_data = bytes.fromhex(ciphertext) + else: + raise ValueError(f"不支持的编码方式: {encoding}") + + # 使用传入的 IV + actual_ciphertext = encrypted_data + + # 创建 AES 解密器 + cipher = Cipher( + algorithms.AES(aes_key.key), + modes.CBC(iv.key), + backend=default_backend(), + ) + decryptor = cipher.decryptor() + plaintext_padded = decryptor.update(actual_ciphertext) + decryptor.finalize() + + # 移除 PKCS7 填充 + padding_length = plaintext_padded[-1] + plaintext_bytes = plaintext_padded[:-padding_length] + + return plaintext_bytes.decode("utf-8") diff --git a/src/utils/examples.py b/src/utils/examples.py deleted file mode 100644 index ffa0f1c..0000000 --- a/src/utils/examples.py +++ /dev/null @@ -1,126 +0,0 @@ -""" -非对称加密辅助函数使用示例 -""" - -from src.utils.crypto import ( - RSAKeyPair, - encrypt_with_rsa, - decrypt_with_rsa, - sign_with_rsa, - verify_rsa_signature, - generate_key_pair_files, - load_key_pair_from_files, -) - - -def example_generate_keys(): - """示例:生成 RSA 密钥对""" - print("=== 生成 RSA 密钥对 ===") - key_pair = RSAKeyPair.generate(key_size=2048) - - private_pem = key_pair.get_private_key_pem() - public_pem = key_pair.get_public_key_pem() - - print("私钥(PEM 格式,前 100 字符):") - print(private_pem[:100] + "...") - print("\n公钥(PEM 格式,前 100 字符):") - print(public_pem[:100] + "...") - - return private_pem, public_pem - - -def example_encrypt_decrypt(private_pem, public_pem): - """示例:加密和解密""" - print("\n=== RSA 加密和解密 ===") - - plaintext = "Hello, World! This is a secret message." - print(f"明文: {plaintext}") - - # 加密 - ciphertext = encrypt_with_rsa(plaintext, public_pem, encoding="base64") - print(f"密文(Base64): {ciphertext[:50]}...") - - # 解密 - decrypted = decrypt_with_rsa(ciphertext, private_pem, encoding="base64") - print(f"解密后的明文: {decrypted}") - print(f"解密成功: {decrypted == plaintext}") - - return ciphertext - - -def example_sign_verify(private_pem, public_pem): - """示例:数字签名和验证""" - print("\n=== RSA 数字签名和验证 ===") - - message = "This message needs to be signed" - print(f"原始消息: {message}") - - # 签名 - signature = sign_with_rsa(message, private_pem, encoding="base64") - print(f"签名(Base64): {signature[:50]}...") - - # 验证签名 - is_valid = verify_rsa_signature(message, signature, public_pem, encoding="base64") - print(f"签名有效: {is_valid}") - - # 尝试验证被篡改的消息 - tampered_message = "This message has been tampered with" - is_tampered_valid = verify_rsa_signature( - tampered_message, signature, public_pem, encoding="base64" - ) - print(f"篡改消息的签名有效: {is_tampered_valid}") - - return signature - - -def example_key_files(): - """示例:保存和加载密钥文件""" - print("\n=== 密钥文件操作 ===") - - private_path = "/tmp/private_key.pem" - public_path = "/tmp/public_key.pem" - - # 生成并保存密钥对 - private_pem, public_pem = generate_key_pair_files(private_path, public_path) - print(f"密钥对已保存到:") - print(f" 私钥: {private_path}") - print(f" 公钥: {public_path}") - - # 从文件加载密钥对 - loaded_key_pair = load_key_pair_from_files( - private_key_path=private_path, public_key_path=public_path - ) - print(f"密钥对已从文件加载") - print(f" 包含私钥: {loaded_key_pair.private_key is not None}") - print(f" 包含公钥: {loaded_key_pair.public_key is not None}") - - -def example_hex_encoding(): - """示例:使用 Hex 编码替代 Base64""" - print("\n=== Hex 编码示例 ===") - - key_pair = RSAKeyPair.generate(key_size=2048) - private_pem = key_pair.get_private_key_pem() - public_pem = key_pair.get_public_key_pem() - - plaintext = "Secret data" - print(f"明文: {plaintext}") - - # 使用 Hex 编码加密 - ciphertext_hex = encrypt_with_rsa(plaintext, public_pem, encoding="hex") - print(f"密文(Hex): {ciphertext_hex[:50]}...") - - # 使用 Hex 编码解密 - decrypted = decrypt_with_rsa(ciphertext_hex, private_pem, encoding="hex") - print(f"解密后的明文: {decrypted}") - - -if __name__ == "__main__": - # 运行所有示例 - private_pem, public_pem = example_generate_keys() - example_encrypt_decrypt(private_pem, public_pem) - example_sign_verify(private_pem, public_pem) - example_key_files() - example_hex_encoding() - - print("\n✓ 所有示例执行完成!") diff --git a/src/utils/integration_examples.py b/src/utils/integration_examples.py deleted file mode 100644 index bbaf119..0000000 --- a/src/utils/integration_examples.py +++ /dev/null @@ -1,348 +0,0 @@ -""" -非对称加密辅助函数在项目中的集成示例 -""" - -import json -from datetime import datetime -from typing import Optional, Dict, Any - -from src.utils.crypto import ( - RSAKeyPair, - encrypt_with_rsa, - decrypt_with_rsa, - sign_with_rsa, - verify_rsa_signature, -) - - -class SecureAPIRequest: - """ - 安全的 API 请求签名示例 - - 用于与第三方 API 进行安全交互,确保请求的真实性和完整性 - """ - - def __init__(self, private_key_pem: str, app_id: str): - """ - 初始化安全 API 请求 - - Args: - private_key_pem: 应用的 RSA 私钥(PEM 格式) - app_id: 应用 ID - """ - self.private_key_pem = private_key_pem - self.app_id = app_id - - def create_signed_request( - self, endpoint: str, data: Dict[str, Any] - ) -> Dict[str, Any]: - """ - 创建带签名的 API 请求 - - Args: - endpoint: API 端点路径 - data: 请求数据 - - Returns: - Dict: 包含签名的完整请求体 - """ - timestamp = str(int(datetime.now().timestamp())) - - # 构建请求内容 - request_content = json.dumps( - { - "app_id": self.app_id, - "endpoint": endpoint, - "timestamp": timestamp, - "data": data, - }, - separators=(",", ":"), - ) - - # 对请求内容签名 - signature = sign_with_rsa(request_content, self.private_key_pem) - - return { - "request": request_content, - "signature": signature, - "app_id": self.app_id, - } - - @staticmethod - def verify_request(request_json: str, signature: str, public_key_pem: str) -> bool: - """ - 验证 API 请求签名 - - Args: - request_json: 请求内容(JSON 字符串) - signature: 请求签名 - public_key_pem: 应用的 RSA 公钥(PEM 格式) - - Returns: - bool: 签名是否有效 - """ - return verify_rsa_signature(request_json, signature, public_key_pem) - - -class SecureDataEncryption: - """ - 敏感数据加密示例 - - 用于在存储或传输敏感数据时进行加密 - """ - - def __init__(self, public_key_pem: str, private_key_pem: str): - """ - 初始化数据加密工具 - - Args: - public_key_pem: 公钥(用于加密) - private_key_pem: 私钥(用于解密) - """ - self.public_key_pem = public_key_pem - self.private_key_pem = private_key_pem - - def encrypt_user_data(self, user_id: str, sensitive_data: str) -> Dict[str, str]: - """ - 加密用户敏感数据 - - Args: - user_id: 用户 ID - sensitive_data: 敏感数据(如密码、令牌等) - - Returns: - Dict: 包含加密数据和元数据的字典 - """ - # 为了避免超过 RSA 加密限制,先对数据进行摘要处理 - # 如果数据较大,应使用混合加密方式 - encrypted = encrypt_with_rsa(sensitive_data, self.public_key_pem) - - return { - "user_id": user_id, - "encrypted_data": encrypted, - "timestamp": datetime.now().isoformat(), - "algorithm": "RSA-OAEP", - } - - def decrypt_user_data(self, encrypted_data: str) -> str: - """ - 解密用户敏感数据 - - Args: - encrypted_data: 加密的数据 - - Returns: - str: 解密后的原始数据 - """ - return decrypt_with_rsa(encrypted_data, self.private_key_pem) - - -class CertificateAuthority: - """ - 简单的证书颁发机构实现 - - 用于证书的签发和验证 - """ - - def __init__(self, ca_private_key_pem: str, ca_public_key_pem: str): - """ - 初始化 CA - - Args: - ca_private_key_pem: CA 的私钥 - ca_public_key_pem: CA 的公钥 - """ - self.ca_private_key_pem = ca_private_key_pem - self.ca_public_key_pem = ca_public_key_pem - - def issue_certificate( - self, subject: str, subject_public_key_pem: str, validity_days: int = 365 - ) -> Dict[str, Any]: - """ - 颁发证书 - - Args: - subject: 证书主体(如用户 ID、服务器名称等) - subject_public_key_pem: 主体的公钥 - validity_days: 有效期(天数) - - Returns: - Dict: 颁发的证书 - """ - now = datetime.now().isoformat() - valid_until = datetime.fromtimestamp( - datetime.now().timestamp() + validity_days * 86400 - ).isoformat() - - cert_data = json.dumps( - { - "subject": subject, - "public_key": subject_public_key_pem, - "issued_at": now, - "valid_until": valid_until, - "issuer": "CA", - } - ) - - # CA 对证书进行签名 - signature = sign_with_rsa(cert_data, self.ca_private_key_pem) - - return { - "certificate": cert_data, - "signature": signature, - "issued_at": now, - } - - def verify_certificate(self, cert_data: str, signature: str) -> bool: - """ - 验证证书签名 - - Args: - cert_data: 证书数据 - signature: 证书签名 - - Returns: - bool: 证书是否有效 - """ - return verify_rsa_signature(cert_data, signature, self.ca_public_key_pem) - - -class AuthenticationToken: - """ - 认证令牌签名示例 - - 用于生成和验证认证令牌 - """ - - def __init__(self, private_key_pem: str, public_key_pem: str, issuer: str): - """ - 初始化令牌生成器 - - Args: - private_key_pem: 私钥(用于签名) - public_key_pem: 公钥(用于验证) - issuer: 令牌颁发者 - """ - self.private_key_pem = private_key_pem - self.public_key_pem = public_key_pem - self.issuer = issuer - - def generate_token( - self, user_id: str, permissions: list, expires_in_hours: int = 24 - ) -> str: - """ - 生成签名的认证令牌 - - Args: - user_id: 用户 ID - permissions: 权限列表 - expires_in_hours: 过期时间(小时) - - Returns: - str: 签名的令牌 - """ - issued_at = datetime.now().isoformat() - expires_at = datetime.fromtimestamp( - datetime.now().timestamp() + expires_in_hours * 3600 - ).isoformat() - - token_data = json.dumps( - { - "user_id": user_id, - "permissions": permissions, - "issuer": self.issuer, - "issued_at": issued_at, - "expires_at": expires_at, - } - ) - - # 对令牌签名 - signature = sign_with_rsa(token_data, self.private_key_pem) - - # 将令牌和签名组合(使用分隔符) - return f"{token_data}.{signature}" - - def verify_token(self, token: str) -> Optional[Dict[str, Any]]: - """ - 验证令牌并返回令牌数据 - - Args: - token: 签名的令牌字符串 - - Returns: - Optional[Dict]: 令牌数据(如果有效)或 None(如果无效) - """ - try: - token_data, signature = token.rsplit(".", 1) - - # 验证签名 - if not verify_rsa_signature(token_data, signature, self.public_key_pem): - return None - - # 解析令牌数据 - token_dict = json.loads(token_data) - - # 验证过期时间 - expires_at = datetime.fromisoformat(token_dict["expires_at"]) - if datetime.now() > expires_at: - return None - - return token_dict - except Exception: - return None - - -# 使用示例 -if __name__ == "__main__": - # 生成密钥对 - print("=== 生成密钥对 ===") - key_pair = RSAKeyPair.generate(key_size=2048) - private_pem = key_pair.get_private_key_pem() - public_pem = key_pair.get_public_key_pem() - - # 示例 1: 安全 API 请求 - print("\n=== 安全 API 请求 ===") - api_request = SecureAPIRequest(private_pem, app_id="app_001") - signed_request = api_request.create_signed_request( - endpoint="/api/users/login", - data={"username": "user@example.com", "password": "secret"}, - ) - print(f"签名请求: {signed_request['signature'][:50]}...") - is_valid = SecureAPIRequest.verify_request( - signed_request["request"], signed_request["signature"], public_pem - ) - print(f"请求签名有效: {is_valid}") - - # 示例 2: 数据加密 - print("\n=== 数据加密 ===") - data_encryption = SecureDataEncryption(public_pem, private_pem) - encrypted = data_encryption.encrypt_user_data( - user_id="user_123", sensitive_data="my_secret_token_12345" - ) - print(f"加密数据: {encrypted['encrypted_data'][:50]}...") - decrypted = data_encryption.decrypt_user_data(encrypted["encrypted_data"]) - print(f"解密后: {decrypted}") - - # 示例 3: 证书签发 - print("\n=== 证书颁发 ===") - ca = CertificateAuthority(private_pem, public_pem) - certificate = ca.issue_certificate( - subject="server.example.com", subject_public_key_pem=public_pem - ) - print(f"证书签名: {certificate['signature'][:50]}...") - is_valid = ca.verify_certificate( - certificate["certificate"], certificate["signature"] - ) - print(f"证书有效: {is_valid}") - - # 示例 4: 认证令牌 - print("\n=== 认证令牌 ===") - auth_token = AuthenticationToken(private_pem, public_pem, issuer="auth_server") - token = auth_token.generate_token( - user_id="user_456", permissions=["read", "write"], expires_in_hours=24 - ) - print(f"生成的令牌: {token[:50]}...") - token_data = auth_token.verify_token(token) - print(f"令牌数据: {token_data}") - - print("\n✓ 所有集成示例执行完成!") diff --git a/src/utils/test_crypto.py b/src/utils/test_crypto.py index 1a24e78..d143bf5 100644 --- a/src/utils/test_crypto.py +++ b/src/utils/test_crypto.py @@ -1,200 +1,13 @@ -""" -非对称加密辅助函数单元测试 -""" +from unittest import TestCase -import unittest - -from src.utils.crypto import ( - RSAKeyPair, - encrypt_with_rsa, - decrypt_with_rsa, - sign_with_rsa, - verify_rsa_signature, -) +from src.utils.crypto import encrypt_with_aes, AESKey -class TestRSAKeyPair(unittest.TestCase): - """测试 RSAKeyPair 类""" - - def setUp(self): - """测试前准备""" - self.key_pair = RSAKeyPair.generate(key_size=2048) - - def test_generate_keys(self): - """测试生成密钥对""" - self.assertIsNotNone(self.key_pair.private_key) - self.assertIsNotNone(self.key_pair.public_key) - - def test_get_private_key_pem(self): - """测试获取 PEM 格式私钥""" - private_pem = self.key_pair.get_private_key_pem() - self.assertIn("-----BEGIN PRIVATE KEY-----", private_pem) - self.assertIn("-----END PRIVATE KEY-----", private_pem) - - def test_get_public_key_pem(self): - """测试获取 PEM 格式公钥""" - public_pem = self.key_pair.get_public_key_pem() - self.assertIn("-----BEGIN PUBLIC KEY-----", public_pem) - self.assertIn("-----END PUBLIC KEY-----", public_pem) - - def test_load_private_key_from_pem(self): - """测试从 PEM 字符串加载私钥""" - private_pem = self.key_pair.get_private_key_pem() - loaded_key_pair = RSAKeyPair.load_private_key_from_pem(private_pem) - - self.assertIsNotNone(loaded_key_pair.private_key) - self.assertIsNotNone(loaded_key_pair.public_key) - - def test_load_public_key_from_pem(self): - """测试从 PEM 字符串加载公钥""" - public_pem = self.key_pair.get_public_key_pem() - loaded_key_pair = RSAKeyPair.load_public_key_from_pem(public_pem) - - self.assertIsNone(loaded_key_pair.private_key) - self.assertIsNotNone(loaded_key_pair.public_key) - - -class TestRSAEncryption(unittest.TestCase): - """测试 RSA 加密和解密""" - - def setUp(self): - """测试前准备""" - self.key_pair = RSAKeyPair.generate(key_size=2048) - self.private_pem = self.key_pair.get_private_key_pem() - self.public_pem = self.key_pair.get_public_key_pem() - self.plaintext = "Hello, World!" - - def test_encrypt_decrypt_base64(self): - """测试使用 Base64 编码的加密和解密""" - ciphertext = encrypt_with_rsa( - self.plaintext, self.public_pem, encoding="base64" +class Test(TestCase): + def test_encrypt_with_aes(self): + result = encrypt_with_aes( + "34324345345345", + AESKey.load_from_base64("P0x6Gy6dXIpPbhE7PHxaHbfZHhsbT2qNPlx3qbHTP1o="), + AESKey.load_from_base64("nywao1XkDXeYwbPeWh+SxA=="), ) - decrypted = decrypt_with_rsa(ciphertext, self.private_pem, encoding="base64") - - self.assertEqual(self.plaintext, decrypted) - self.assertNotEqual(self.plaintext, ciphertext) - - def test_encrypt_decrypt_hex(self): - """测试使用 Hex 编码的加密和解密""" - ciphertext = encrypt_with_rsa(self.plaintext, self.public_pem, encoding="hex") - decrypted = decrypt_with_rsa(ciphertext, self.private_pem, encoding="hex") - - self.assertEqual(self.plaintext, decrypted) - self.assertNotEqual(self.plaintext, ciphertext) - - def test_encrypt_unicode(self): - """测试加密 Unicode 文本""" - unicode_text = "你好,世界!🔐" - ciphertext = encrypt_with_rsa(unicode_text, self.public_pem) - decrypted = decrypt_with_rsa(ciphertext, self.private_pem) - - self.assertEqual(unicode_text, decrypted) - - def test_encrypt_with_invalid_encoding(self): - """测试使用无效编码方式加密""" - with self.assertRaises(ValueError): - encrypt_with_rsa(self.plaintext, self.public_pem, encoding="invalid") - - def test_decrypt_with_invalid_encoding(self): - """测试使用无效编码方式解密""" - ciphertext = encrypt_with_rsa( - self.plaintext, self.public_pem, encoding="base64" - ) - with self.assertRaises(ValueError): - decrypt_with_rsa(ciphertext, self.private_pem, encoding="invalid") - - -class TestRSASignature(unittest.TestCase): - """测试 RSA 数字签名""" - - def setUp(self): - """测试前准备""" - self.key_pair = RSAKeyPair.generate(key_size=2048) - self.private_pem = self.key_pair.get_private_key_pem() - self.public_pem = self.key_pair.get_public_key_pem() - self.message = "This is a message to sign" - - def test_sign_verify_base64(self): - """测试使用 Base64 编码的签名和验证""" - signature = sign_with_rsa(self.message, self.private_pem, encoding="base64") - is_valid = verify_rsa_signature( - self.message, signature, self.public_pem, encoding="base64" - ) - - self.assertTrue(is_valid) - - def test_sign_verify_hex(self): - """测试使用 Hex 编码的签名和验证""" - signature = sign_with_rsa(self.message, self.private_pem, encoding="hex") - is_valid = verify_rsa_signature( - self.message, signature, self.public_pem, encoding="hex" - ) - - self.assertTrue(is_valid) - - def test_verify_invalid_signature(self): - """测试验证无效签名""" - signature = sign_with_rsa(self.message, self.private_pem) - # 修改消息 - tampered_message = "This is a tampered message" - is_valid = verify_rsa_signature( - tampered_message, signature, self.public_pem, encoding="base64" - ) - - self.assertFalse(is_valid) - - def test_verify_corrupted_signature(self): - """测试验证损坏的签名""" - signature = sign_with_rsa(self.message, self.private_pem) - # 修改签名 - corrupted_signature = signature[:-10] + "corrupted" - is_valid = verify_rsa_signature( - self.message, corrupted_signature, self.public_pem, encoding="base64" - ) - - self.assertFalse(is_valid) - - def test_sign_unicode_message(self): - """测试签名 Unicode 消息""" - unicode_message = "签名测试消息 🔐" - signature = sign_with_rsa(unicode_message, self.private_pem) - is_valid = verify_rsa_signature(unicode_message, signature, self.public_pem) - - self.assertTrue(is_valid) - - -class TestEdgeCases(unittest.TestCase): - """测试边界情况""" - - def setUp(self): - """测试前准备""" - self.key_pair = RSAKeyPair.generate(key_size=2048) - self.private_pem = self.key_pair.get_private_key_pem() - self.public_pem = self.key_pair.get_public_key_pem() - - def test_encrypt_empty_string(self): - """测试加密空字符串""" - plaintext = "" - ciphertext = encrypt_with_rsa(plaintext, self.public_pem) - decrypted = decrypt_with_rsa(ciphertext, self.private_pem) - - self.assertEqual(plaintext, decrypted) - - def test_encrypt_long_message(self): - """测试加密长消息(应该失败,RSA 有长度限制)""" - # RSA 2048 位密钥最多能加密约 190 字节 - long_plaintext = "A" * 300 - with self.assertRaises(Exception): - encrypt_with_rsa(long_plaintext, self.public_pem) - - def test_sign_empty_message(self): - """测试对空消息签名""" - message = "" - signature = sign_with_rsa(message, self.private_pem) - is_valid = verify_rsa_signature(message, signature, self.public_pem) - - self.assertTrue(is_valid) - - -if __name__ == "__main__": - unittest.main() + print(result)