From 95c740dee242544ca743856108f4913230cef4e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=A8=8B=E5=BA=8F=E5=91=98=E9=98=BF=E6=B1=9F-Relakkes?= Date: Fri, 26 Sep 2025 17:38:44 +0800 Subject: [PATCH] refine: harden typer cli defaults --- cmd_arg/arg.py | 45 ++++++++++++++++++++++++++++++++++++++------- 1 file changed, 38 insertions(+), 7 deletions(-) diff --git a/cmd_arg/arg.py b/cmd_arg/arg.py index 1c08b3f..ba20e65 100644 --- a/cmd_arg/arg.py +++ b/cmd_arg/arg.py @@ -11,7 +11,7 @@ from enum import Enum from types import SimpleNamespace -from typing import Optional +from typing import Iterable, Optional, Sequence, Type, TypeVar import typer from typing_extensions import Annotated @@ -20,6 +20,9 @@ import config from tools.utils import str2bool +EnumT = TypeVar("EnumT", bound=Enum) + + class PlatformEnum(str, Enum): """支持的媒体平台枚举""" @@ -70,7 +73,33 @@ def _to_bool(value: bool | str) -> bool: return str2bool(value) -async def parse_cmd(): +def _coerce_enum( + enum_cls: Type[EnumT], + value: EnumT | str, + default: EnumT, +) -> EnumT: + """Safely convert a raw config value to an enum member.""" + + if isinstance(value, enum_cls): + return value + + try: + return enum_cls(value) + except ValueError: + typer.secho( + f"⚠️ 配置值 '{value}' 不在 {enum_cls.__name__} 支持的范围内,已回退到默认值 '{default.value}'.", + fg=typer.colors.YELLOW, + ) + return default + + +def _normalize_argv(argv: Optional[Sequence[str]]) -> Optional[Iterable[str]]: + if argv is None: + return None + return list(argv) + + +async def parse_cmd(argv: Optional[Sequence[str]] = None): """使用 Typer 解析命令行参数。""" def main( @@ -81,7 +110,7 @@ async def parse_cmd(): help="媒体平台选择 (xhs=小红书 | dy=抖音 | ks=快手 | bili=哔哩哔哩 | wb=微博 | tieba=百度贴吧 | zhihu=知乎)", rich_help_panel="基础配置", ), - ] = PlatformEnum(config.PLATFORM), + ] = _coerce_enum(PlatformEnum, config.PLATFORM, PlatformEnum.XHS), lt: Annotated[ LoginTypeEnum, typer.Option( @@ -89,7 +118,7 @@ async def parse_cmd(): help="登录方式 (qrcode=二维码 | phone=手机号 | cookie=Cookie)", rich_help_panel="账号配置", ), - ] = LoginTypeEnum(config.LOGIN_TYPE), + ] = _coerce_enum(LoginTypeEnum, config.LOGIN_TYPE, LoginTypeEnum.QRCODE), crawler_type: Annotated[ CrawlerTypeEnum, typer.Option( @@ -97,7 +126,7 @@ async def parse_cmd(): help="爬取类型 (search=搜索 | detail=详情 | creator=创作者)", rich_help_panel="基础配置", ), - ] = CrawlerTypeEnum(config.CRAWLER_TYPE), + ] = _coerce_enum(CrawlerTypeEnum, config.CRAWLER_TYPE, CrawlerTypeEnum.SEARCH), start: Annotated[ int, typer.Option( @@ -139,7 +168,9 @@ async def parse_cmd(): help="数据保存方式 (csv=CSV文件 | db=MySQL数据库 | json=JSON文件 | sqlite=SQLite数据库)", rich_help_panel="存储配置", ), - ] = SaveDataOptionEnum(config.SAVE_DATA_OPTION), + ] = _coerce_enum( + SaveDataOptionEnum, config.SAVE_DATA_OPTION, SaveDataOptionEnum.JSON + ), init_db: Annotated[ Optional[InitDbOptionEnum], typer.Option( @@ -190,6 +221,6 @@ async def parse_cmd(): command = typer.main.get_command(main) try: - return command.main(standalone_mode=False) + return command.main(args=_normalize_argv(argv), standalone_mode=False) except typer.Exit as exc: # pragma: no cover - CLI exit paths raise SystemExit(exc.exit_code) from exc