diff --git a/api/schemas/crawler.py b/api/schemas/crawler.py index 283cf2d..6eb3b1b 100644 --- a/api/schemas/crawler.py +++ b/api/schemas/crawler.py @@ -18,7 +18,10 @@ from enum import Enum from typing import Optional, Literal -from pydantic import BaseModel +from pydantic import BaseModel, Field + + +MAX_API_LIMIT_COUNT = 10000 class PlatformEnum(str, Enum): @@ -71,8 +74,8 @@ class CrawlerStartRequest(BaseModel): save_option: SaveDataOptionEnum = SaveDataOptionEnum.JSONL cookies: str = "" headless: bool = False - max_notes_count: Optional[int] = None - max_comments_count: Optional[int] = None + max_notes_count: Optional[int] = Field(default=None, ge=1, le=MAX_API_LIMIT_COUNT) + max_comments_count: Optional[int] = Field(default=None, ge=1, le=MAX_API_LIMIT_COUNT) class CrawlerStatusResponse(BaseModel): diff --git a/cmd_arg/arg.py b/cmd_arg/arg.py index b18b09e..074dd13 100644 --- a/cmd_arg/arg.py +++ b/cmd_arg/arg.py @@ -320,10 +320,18 @@ async def parse_cmd(argv: Optional[Sequence[str]] = None): str, typer.Option( "--ip_proxy_provider_name", - help="IP proxy provider name (kuaidaili | wandouhttp)", + help="IP proxy provider name (kuaidaili | wandouhttp | static)", rich_help_panel="Proxy Configuration", ), ] = config.IP_PROXY_PROVIDER_NAME, + static_proxy_url: Annotated[ + str, + typer.Option( + "--static_proxy_url", + help="Static proxy URL, for example http://user:password@host:port", + rich_help_panel="Proxy Configuration", + ), + ] = config.STATIC_PROXY_URL, ) -> SimpleNamespace: """MediaCrawler 命令行入口""" @@ -356,6 +364,7 @@ async def parse_cmd(argv: Optional[Sequence[str]] = None): config.ENABLE_IP_PROXY = enable_ip_proxy_value config.IP_PROXY_POOL_COUNT = ip_proxy_pool_count config.IP_PROXY_PROVIDER_NAME = ip_proxy_provider_name + config.STATIC_PROXY_URL = static_proxy_url # Set platform-specific ID lists for detail/creator mode if specified_id_list: diff --git a/config/base_config.py b/config/base_config.py index b2f084d..28a852e 100644 --- a/config/base_config.py +++ b/config/base_config.py @@ -34,14 +34,14 @@ CRAWLER_TYPE = ( ENABLE_IP_PROXY = False # Number of proxy IP pools -IP_PROXY_POOL_COUNT = 1 +IP_PROXY_POOL_COUNT = 2 # Proxy IP provider name -IP_PROXY_PROVIDER_NAME = "static" # kuaidaili | wandouhttp | static +IP_PROXY_PROVIDER_NAME = "kuaidaili" # kuaidaili | wandouhttp | static # Static proxy configuration (used when IP_PROXY_PROVIDER_NAME is set to "static") # Format: "http://your_home_domain:port" or "http://user:password@your_home_domain:port" -STATIC_PROXY_URL = "http://your_home_domain:port" +STATIC_PROXY_URL = "" # Setting to True will not open the browser (headless browser) # Setting False will open a browser diff --git a/proxy/proxy_ip_pool.py b/proxy/proxy_ip_pool.py index 95289dc..8041d38 100644 --- a/proxy/proxy_ip_pool.py +++ b/proxy/proxy_ip_pool.py @@ -22,7 +22,9 @@ # @Time : 2023/12/2 13:45 # @Desc : IP proxy pool implementation import random +import time from typing import Dict, List +from urllib.parse import unquote, urlparse import httpx from tenacity import retry, stop_after_attempt, wait_fixed @@ -152,33 +154,33 @@ class ProxyIpPool: class StaticProxyProvider(ProxyProvider): async def get_proxy(self, num: int) -> List[IpInfoModel]: - from urllib.parse import urlparse - import time - proxy_url = getattr(config, "STATIC_PROXY_URL", "") if not proxy_url: utils.logger.warning("[StaticProxyProvider] STATIC_PROXY_URL is not configured!") return [] - + try: parsed = urlparse(proxy_url) + scheme = parsed.scheme or "http" + if scheme not in {"http", "https"}: + utils.logger.error(f"[StaticProxyProvider] Unsupported proxy scheme: {scheme}") + return [] + ip = parsed.hostname or "" - port = parsed.port or 80 - user = parsed.username or "" - password = parsed.password or "" - protocol = parsed.scheme + "://" if parsed.scheme else "http://" - - # Static proxy doesn't expire - expired_time_ts = int(time.time()) + 99999999 - + port = parsed.port or (443 if scheme == "https" else 80) + if not ip: + utils.logger.error("[StaticProxyProvider] STATIC_PROXY_URL host is empty!") + return [] + return [ IpInfoModel( ip=ip, port=port, - user=user, - password=password, - protocol=protocol, - expired_time_ts=expired_time_ts + user=unquote(parsed.username or ""), + password=unquote(parsed.password or ""), + protocol=f"{scheme}://", + # Static proxy doesn't expire. + expired_time_ts=int(time.time()) + 99999999, ) ] except Exception as e: @@ -189,7 +191,7 @@ class StaticProxyProvider(ProxyProvider): IpProxyProvider: Dict[str, ProxyProvider] = { ProviderNameEnum.KUAI_DAILI_PROVIDER.value: new_kuai_daili_proxy(), ProviderNameEnum.WANDOU_HTTP_PROVIDER.value: new_wandou_http_proxy(), - "static": StaticProxyProvider(), + ProviderNameEnum.STATIC_PROVIDER.value: StaticProxyProvider(), } @@ -200,7 +202,7 @@ async def create_ip_pool(ip_pool_count: int, enable_validate_ip: bool) -> ProxyI :param enable_validate_ip: Whether to enable IP proxy validation :return: """ - is_static = config.IP_PROXY_PROVIDER_NAME == "static" + is_static = config.IP_PROXY_PROVIDER_NAME == ProviderNameEnum.STATIC_PROVIDER.value pool = ProxyIpPool( ip_pool_count=ip_pool_count, enable_validate_ip=False if is_static else enable_validate_ip, diff --git a/proxy/types.py b/proxy/types.py index e203141..57f3fe1 100644 --- a/proxy/types.py +++ b/proxy/types.py @@ -32,6 +32,7 @@ from pydantic import BaseModel, Field class ProviderNameEnum(Enum): KUAI_DAILI_PROVIDER: str = "kuaidaili" WANDOU_HTTP_PROVIDER: str = "wandouhttp" + STATIC_PROVIDER: str = "static" class IpInfoModel(BaseModel): diff --git a/tests/test_api_limits.py b/tests/test_api_limits.py index 3154684..0cb65bc 100644 --- a/tests/test_api_limits.py +++ b/tests/test_api_limits.py @@ -13,7 +13,7 @@ async def test_cmd_arg_crawler_max_notes_count(): # Store original values orig_notes = config.CRAWLER_MAX_NOTES_COUNT orig_comments = config.CRAWLER_MAX_COMMENTS_COUNT_SINGLENOTES - + try: await parse_cmd([ "--platform", "xhs", @@ -28,7 +28,7 @@ async def test_cmd_arg_crawler_max_notes_count(): def test_crawler_manager_build_command(): cm = CrawlerManager() - + # 1. No max limits passed in API request req1 = CrawlerStartRequest( platform=PlatformEnum.XHS, @@ -64,10 +64,10 @@ def test_crawler_manager_build_command(): def test_api_start_crawler_with_limits(): client = TestClient(app) - + with patch("api.routers.crawler.crawler_manager.start", new_callable=AsyncMock) as mock_start: mock_start.return_value = True - + # Test case 1: with limits response = client.post("/api/crawler/start", json={ "platform": "xhs", @@ -77,10 +77,10 @@ def test_api_start_crawler_with_limits(): "max_notes_count": 50, "max_comments_count": 5 }) - + assert response.status_code == 200 assert response.json() == {"status": "ok", "message": "Crawler started successfully"} - + mock_start.assert_called_once() called_request = mock_start.call_args[0][0] assert called_request.platform == PlatformEnum.XHS @@ -89,10 +89,10 @@ def test_api_start_crawler_with_limits(): def test_api_start_crawler_without_limits(): client = TestClient(app) - + with patch("api.routers.crawler.crawler_manager.start", new_callable=AsyncMock) as mock_start: mock_start.return_value = True - + # Test case 2: without limits response = client.post("/api/crawler/start", json={ "platform": "xhs", @@ -100,10 +100,38 @@ def test_api_start_crawler_without_limits(): "crawler_type": "search", "keywords": "test" }) - + assert response.status_code == 200 mock_start.assert_called_once() called_request = mock_start.call_args[0][0] assert called_request.platform == PlatformEnum.XHS assert called_request.max_notes_count is None assert called_request.max_comments_count is None + + +@pytest.mark.parametrize( + ("field_name", "value"), + [ + ("max_notes_count", 0), + ("max_notes_count", -1), + ("max_notes_count", 10001), + ("max_comments_count", 0), + ("max_comments_count", -1), + ("max_comments_count", 10001), + ], +) +def test_api_rejects_invalid_limits(field_name, value): + client = TestClient(app) + payload = { + "platform": "xhs", + "login_type": "qrcode", + "crawler_type": "search", + "keywords": "test", + field_name: value, + } + + with patch("api.routers.crawler.crawler_manager.start", new_callable=AsyncMock) as mock_start: + response = client.post("/api/crawler/start", json=payload) + + assert response.status_code == 422 + mock_start.assert_not_called() diff --git a/tests/test_static_proxy_provider.py b/tests/test_static_proxy_provider.py new file mode 100644 index 0000000..c88647c --- /dev/null +++ b/tests/test_static_proxy_provider.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +import pytest + +import config +from proxy.proxy_ip_pool import StaticProxyProvider, create_ip_pool +from proxy.types import ProviderNameEnum + + +def test_default_proxy_provider_remains_existing_provider(): + assert config.IP_PROXY_PROVIDER_NAME == ProviderNameEnum.KUAI_DAILI_PROVIDER.value + assert config.IP_PROXY_POOL_COUNT == 2 + assert config.STATIC_PROXY_URL == "" + + +@pytest.mark.asyncio +async def test_static_proxy_provider_parses_proxy_url(monkeypatch): + monkeypatch.setattr(config, "STATIC_PROXY_URL", "http://user:p%40ss@example.com:8080") + + proxies = await StaticProxyProvider().get_proxy(1) + + assert len(proxies) == 1 + proxy = proxies[0] + assert proxy.ip == "example.com" + assert proxy.port == 8080 + assert proxy.user == "user" + assert proxy.password == "p@ss" + assert proxy.protocol == "http://" + assert proxy.expired_time_ts is not None + + +@pytest.mark.asyncio +async def test_static_proxy_provider_rejects_invalid_url(monkeypatch): + monkeypatch.setattr(config, "STATIC_PROXY_URL", "http://your_home_domain:port") + + proxies = await StaticProxyProvider().get_proxy(1) + + assert proxies == [] + + +@pytest.mark.asyncio +async def test_static_proxy_pool_disables_validation(monkeypatch): + monkeypatch.setattr(config, "IP_PROXY_PROVIDER_NAME", ProviderNameEnum.STATIC_PROVIDER.value) + monkeypatch.setattr(config, "STATIC_PROXY_URL", "https://example.com:8443") + + pool = await create_ip_pool(ip_pool_count=2, enable_validate_ip=True) + + assert pool.enable_validate_ip is False + assert len(pool.proxy_list) == 1 + assert pool.proxy_list[0].protocol == "https://"