diff --git a/api/schemas/crawler.py b/api/schemas/crawler.py index f31ef82..6eb3b1b 100644 --- a/api/schemas/crawler.py +++ b/api/schemas/crawler.py @@ -18,7 +18,10 @@ from enum import Enum from typing import Optional, Literal -from pydantic import BaseModel +from pydantic import BaseModel, Field + + +MAX_API_LIMIT_COUNT = 10000 class PlatformEnum(str, Enum): @@ -71,6 +74,8 @@ class CrawlerStartRequest(BaseModel): save_option: SaveDataOptionEnum = SaveDataOptionEnum.JSONL cookies: str = "" headless: bool = False + max_notes_count: Optional[int] = Field(default=None, ge=1, le=MAX_API_LIMIT_COUNT) + max_comments_count: Optional[int] = Field(default=None, ge=1, le=MAX_API_LIMIT_COUNT) class CrawlerStatusResponse(BaseModel): diff --git a/api/services/crawler_manager.py b/api/services/crawler_manager.py index f0fb228..9af954b 100644 --- a/api/services/crawler_manager.py +++ b/api/services/crawler_manager.py @@ -225,6 +225,12 @@ class CrawlerManager: cmd.extend(["--get_comment", "true" if config.enable_comments else "false"]) cmd.extend(["--get_sub_comment", "true" if config.enable_sub_comments else "false"]) + if config.max_notes_count is not None: + cmd.extend(["--crawler_max_notes_count", str(config.max_notes_count)]) + + if config.max_comments_count is not None: + cmd.extend(["--max_comments_count_singlenotes", str(config.max_comments_count)]) + if config.cookies: cmd.extend(["--cookies", config.cookies]) diff --git a/cmd_arg/arg.py b/cmd_arg/arg.py index 86199db..074dd13 100644 --- a/cmd_arg/arg.py +++ b/cmd_arg/arg.py @@ -275,6 +275,14 @@ async def parse_cmd(argv: Optional[Sequence[str]] = None): rich_help_panel="Comment Configuration", ), ] = config.CRAWLER_MAX_COMMENTS_COUNT_SINGLENOTES, + crawler_max_notes_count: Annotated[ + int, + typer.Option( + "--crawler_max_notes_count", + help="Maximum number of videos/posts to crawl", + rich_help_panel="Basic Configuration", + ), + ] = config.CRAWLER_MAX_NOTES_COUNT, max_concurrency_num: Annotated[ int, typer.Option( @@ 
-312,10 +320,18 @@ async def parse_cmd(argv: Optional[Sequence[str]] = None): str, typer.Option( "--ip_proxy_provider_name", - help="IP proxy provider name (kuaidaili | wandouhttp)", + help="IP proxy provider name (kuaidaili | wandouhttp | static)", rich_help_panel="Proxy Configuration", ), ] = config.IP_PROXY_PROVIDER_NAME, + static_proxy_url: Annotated[ + str, + typer.Option( + "--static_proxy_url", + help="Static proxy URL, for example http://user:password@host:port", + rich_help_panel="Proxy Configuration", + ), + ] = config.STATIC_PROXY_URL, ) -> SimpleNamespace: """MediaCrawler 命令行入口""" @@ -342,11 +358,13 @@ async def parse_cmd(argv: Optional[Sequence[str]] = None): config.SAVE_DATA_OPTION = save_data_option.value config.COOKIES = cookies config.CRAWLER_MAX_COMMENTS_COUNT_SINGLENOTES = max_comments_count_singlenotes + config.CRAWLER_MAX_NOTES_COUNT = crawler_max_notes_count config.MAX_CONCURRENCY_NUM = max_concurrency_num config.SAVE_DATA_PATH = save_data_path config.ENABLE_IP_PROXY = enable_ip_proxy_value config.IP_PROXY_POOL_COUNT = ip_proxy_pool_count config.IP_PROXY_PROVIDER_NAME = ip_proxy_provider_name + config.STATIC_PROXY_URL = static_proxy_url # Set platform-specific ID lists for detail/creator mode if specified_id_list: diff --git a/config/base_config.py b/config/base_config.py index 49c35ce..28a852e 100644 --- a/config/base_config.py +++ b/config/base_config.py @@ -37,7 +37,11 @@ ENABLE_IP_PROXY = False IP_PROXY_POOL_COUNT = 2 # Proxy IP provider name -IP_PROXY_PROVIDER_NAME = "kuaidaili" # kuaidaili | wandouhttp +IP_PROXY_PROVIDER_NAME = "kuaidaili" # kuaidaili | wandouhttp | static + +# Static proxy configuration (used when IP_PROXY_PROVIDER_NAME is set to "static") +# Format: "http://your_home_domain:port" or "http://user:password@your_home_domain:port" +STATIC_PROXY_URL = "" # Setting to True will not open the browser (headless browser) # Setting False will open a browser diff --git a/proxy/proxy_ip_pool.py b/proxy/proxy_ip_pool.py index 
class StaticProxyProvider(ProxyProvider):
    """Proxy provider backed by the single fixed URL in ``config.STATIC_PROXY_URL``.

    No vendor API is called: the configured URL is parsed and returned as one
    ``IpInfoModel``.
    """

    # A static proxy has no lease, so the expiry is simply pushed far into the
    # future (99999999 seconds is a little over three years from "now").
    _EXPIRY_OFFSET_SECONDS = 99999999
    # Default port per supported scheme, used when the URL does not name a port.
    _DEFAULT_PORTS = {"http": 80, "https": 443}

    async def get_proxy(self, num: int) -> List[IpInfoModel]:
        """Return the configured static proxy as a list with at most one entry.

        :param num: Requested number of proxies. Not used: only one URL is configured.
        :return: A one-element list on success; an empty list when the URL is
            missing, uses a scheme other than http/https, has no host, names
            port 0, or cannot be parsed.
        """
        proxy_url = (getattr(config, "STATIC_PROXY_URL", "") or "").strip()
        if not proxy_url:
            utils.logger.warning("[StaticProxyProvider] STATIC_PROXY_URL is not configured!")
            return []

        # urlparse() only recognises a host after a "//" authority marker. A bare
        # "host:port" or "user:password@host:port" would otherwise be read as
        # "<scheme>:<path>" and rejected, so treat scheme-less values as http.
        if "://" not in proxy_url:
            proxy_url = "http://" + proxy_url.lstrip("/")

        try:
            parsed = urlparse(proxy_url)
            scheme = parsed.scheme or "http"
            if scheme not in self._DEFAULT_PORTS:
                utils.logger.error(f"[StaticProxyProvider] Unsupported proxy scheme: {scheme}")
                return []

            host = parsed.hostname or ""
            if not host:
                utils.logger.error("[StaticProxyProvider] STATIC_PROXY_URL host is empty!")
                return []

            # Accessing .port raises ValueError for a non-numeric or out-of-range
            # port; that is handled by the except clause below.
            port = parsed.port
            if port is None:
                port = self._DEFAULT_PORTS[scheme]
            elif port == 0:
                # An explicit port 0 is not connectable; do not silently swap in
                # the scheme default.
                utils.logger.error("[StaticProxyProvider] STATIC_PROXY_URL port must not be 0!")
                return []

            return [
                IpInfoModel(
                    ip=host,
                    port=port,
                    # Credentials may be percent-encoded inside the URL.
                    user=unquote(parsed.username or ""),
                    password=unquote(parsed.password or ""),
                    protocol=f"{scheme}://",
                    expired_time_ts=int(time.time()) + self._EXPIRY_OFFSET_SECONDS,
                )
            ]
        except Exception as e:
            # Deliberate best-effort boundary: a bad URL is logged and results in
            # an empty list instead of propagating to the caller.
            utils.logger.error(f"[StaticProxyProvider] Parse static proxy url error: {e}")
            return []
def test_crawler_manager_build_command():
    """Check that _build_command emits the limit flags only when limits are set."""
    manager = CrawlerManager()
    notes_flag = "--crawler_max_notes_count"
    comments_flag = "--max_comments_count_singlenotes"
    # Fields shared by both requests; only the two limits differ.
    shared_fields = dict(
        platform=PlatformEnum.XHS,
        login_type=LoginTypeEnum.QRCODE,
        crawler_type=CrawlerTypeEnum.SEARCH,
        keywords="test",
    )

    # Case 1: both limits are None, so neither flag may appear.
    request_without_limits = CrawlerStartRequest(
        max_notes_count=None,
        max_comments_count=None,
        **shared_fields,
    )
    command_without_limits = manager._build_command(request_without_limits)
    for flag in (notes_flag, comments_flag):
        assert flag not in command_without_limits

    # Case 2: both limits given, so each flag is directly followed by its value.
    request_with_limits = CrawlerStartRequest(
        max_notes_count=50,
        max_comments_count=5,
        **shared_fields,
    )
    command_with_limits = manager._build_command(request_with_limits)
    for flag, expected_value in ((notes_flag, "50"), (comments_flag, "5")):
        assert flag in command_with_limits
        flag_position = command_with_limits.index(flag)
        assert command_with_limits[flag_position + 1] == expected_value
def test_api_start_crawler_without_limits():
    """A start request that omits both limits reaches the manager with None for each."""
    request_body = {
        "platform": "xhs",
        "login_type": "qrcode",
        "crawler_type": "search",
        "keywords": "test",
    }
    api_client = TestClient(app)

    with patch("api.routers.crawler.crawler_manager.start", new_callable=AsyncMock) as mock_start:
        mock_start.return_value = True
        response = api_client.post("/api/crawler/start", json=request_body)

        assert response.status_code == 200
        mock_start.assert_called_once()
        # The router passes the parsed request model as the first positional argument.
        forwarded = mock_start.call_args.args[0]
        assert forwarded.platform == PlatformEnum.XHS
        assert forwarded.max_notes_count is None
        assert forwarded.max_comments_count is None


@pytest.mark.parametrize(
    ("field_name", "value"),
    [
        (limit_field, bad_value)
        for limit_field in ("max_notes_count", "max_comments_count")
        for bad_value in (0, -1, 10001)
    ],
)
def test_api_rejects_invalid_limits(field_name, value):
    """Out-of-range limits are rejected with 422 before the manager is called."""
    request_body = {
        "platform": "xhs",
        "login_type": "qrcode",
        "crawler_type": "search",
        "keywords": "test",
    }
    request_body[field_name] = value
    api_client = TestClient(app)

    with patch("api.routers.crawler.crawler_manager.start", new_callable=AsyncMock) as mock_start:
        response = api_client.post("/api/crawler/start", json=request_body)

    assert response.status_code == 422
    mock_start.assert_not_called()
def test_default_proxy_provider_remains_existing_provider():
    """The proxy-related config defaults are kuaidaili, a pool of 2, and no static URL."""
    expected_defaults = (
        ("IP_PROXY_PROVIDER_NAME", ProviderNameEnum.KUAI_DAILI_PROVIDER.value),
        ("IP_PROXY_POOL_COUNT", 2),
        ("STATIC_PROXY_URL", ""),
    )
    for setting_name, expected in expected_defaults:
        assert getattr(config, setting_name) == expected


@pytest.mark.asyncio
async def test_static_proxy_provider_parses_proxy_url(monkeypatch):
    """A full proxy URL is split into host, port, decoded credentials and protocol."""
    monkeypatch.setattr(config, "STATIC_PROXY_URL", "http://user:p%40ss@example.com:8080")
    provider = StaticProxyProvider()

    result = await provider.get_proxy(1)

    assert len(result) == 1
    entry = result[0]
    observed = (entry.ip, entry.port, entry.user, entry.password, entry.protocol)
    # "%40" in the password must come back decoded as "@".
    assert observed == ("example.com", 8080, "user", "p@ss", "http://")
    assert entry.expired_time_ts is not None


@pytest.mark.asyncio
async def test_static_proxy_provider_rejects_invalid_url(monkeypatch):
    """A URL with a non-numeric port yields an empty proxy list."""
    monkeypatch.setattr(config, "STATIC_PROXY_URL", "http://your_home_domain:port")
    provider = StaticProxyProvider()

    assert await provider.get_proxy(1) == []


@pytest.mark.asyncio
async def test_static_proxy_pool_disables_validation(monkeypatch):
    """With the static provider selected, the pool is created with validation off."""
    overrides = {
        "IP_PROXY_PROVIDER_NAME": ProviderNameEnum.STATIC_PROVIDER.value,
        "STATIC_PROXY_URL": "https://example.com:8443",
    }
    for setting_name, setting_value in overrides.items():
        monkeypatch.setattr(config, setting_name, setting_value)

    # Validation is requested here on purpose; the pool is expected to ignore it.
    static_pool = await create_ip_pool(ip_pool_count=2, enable_validate_ip=True)

    assert static_pool.enable_validate_ip is False
    assert len(static_pool.proxy_list) == 1
    only_proxy = static_pool.proxy_list[0]
    assert only_proxy.protocol == "https://"