Merge pull request #900 from zanmeipaul/main

feat: 启动任务接口添加帖子/视频数量与评论数量覆盖支持
This commit is contained in:
程序员阿江-Relakkes
2026-05-29 21:34:07 +08:00
committed by GitHub
8 changed files with 264 additions and 4 deletions

View File

@@ -18,7 +18,10 @@
from enum import Enum from enum import Enum
from typing import Optional, Literal from typing import Optional, Literal
from pydantic import BaseModel from pydantic import BaseModel, Field
MAX_API_LIMIT_COUNT = 10000
class PlatformEnum(str, Enum): class PlatformEnum(str, Enum):
@@ -71,6 +74,8 @@ class CrawlerStartRequest(BaseModel):
save_option: SaveDataOptionEnum = SaveDataOptionEnum.JSONL save_option: SaveDataOptionEnum = SaveDataOptionEnum.JSONL
cookies: str = "" cookies: str = ""
headless: bool = False headless: bool = False
max_notes_count: Optional[int] = Field(default=None, ge=1, le=MAX_API_LIMIT_COUNT)
max_comments_count: Optional[int] = Field(default=None, ge=1, le=MAX_API_LIMIT_COUNT)
class CrawlerStatusResponse(BaseModel): class CrawlerStatusResponse(BaseModel):

View File

@@ -225,6 +225,12 @@ class CrawlerManager:
cmd.extend(["--get_comment", "true" if config.enable_comments else "false"]) cmd.extend(["--get_comment", "true" if config.enable_comments else "false"])
cmd.extend(["--get_sub_comment", "true" if config.enable_sub_comments else "false"]) cmd.extend(["--get_sub_comment", "true" if config.enable_sub_comments else "false"])
if config.max_notes_count is not None:
cmd.extend(["--crawler_max_notes_count", str(config.max_notes_count)])
if config.max_comments_count is not None:
cmd.extend(["--max_comments_count_singlenotes", str(config.max_comments_count)])
if config.cookies: if config.cookies:
cmd.extend(["--cookies", config.cookies]) cmd.extend(["--cookies", config.cookies])

View File

@@ -275,6 +275,14 @@ async def parse_cmd(argv: Optional[Sequence[str]] = None):
rich_help_panel="Comment Configuration", rich_help_panel="Comment Configuration",
), ),
] = config.CRAWLER_MAX_COMMENTS_COUNT_SINGLENOTES, ] = config.CRAWLER_MAX_COMMENTS_COUNT_SINGLENOTES,
crawler_max_notes_count: Annotated[
int,
typer.Option(
"--crawler_max_notes_count",
help="Maximum number of videos/posts to crawl",
rich_help_panel="Basic Configuration",
),
] = config.CRAWLER_MAX_NOTES_COUNT,
max_concurrency_num: Annotated[ max_concurrency_num: Annotated[
int, int,
typer.Option( typer.Option(
@@ -312,10 +320,18 @@ async def parse_cmd(argv: Optional[Sequence[str]] = None):
str, str,
typer.Option( typer.Option(
"--ip_proxy_provider_name", "--ip_proxy_provider_name",
help="IP proxy provider name (kuaidaili | wandouhttp)", help="IP proxy provider name (kuaidaili | wandouhttp | static)",
rich_help_panel="Proxy Configuration", rich_help_panel="Proxy Configuration",
), ),
] = config.IP_PROXY_PROVIDER_NAME, ] = config.IP_PROXY_PROVIDER_NAME,
static_proxy_url: Annotated[
str,
typer.Option(
"--static_proxy_url",
help="Static proxy URL, for example http://user:password@host:port",
rich_help_panel="Proxy Configuration",
),
] = config.STATIC_PROXY_URL,
) -> SimpleNamespace: ) -> SimpleNamespace:
"""MediaCrawler 命令行入口""" """MediaCrawler 命令行入口"""
@@ -342,11 +358,13 @@ async def parse_cmd(argv: Optional[Sequence[str]] = None):
config.SAVE_DATA_OPTION = save_data_option.value config.SAVE_DATA_OPTION = save_data_option.value
config.COOKIES = cookies config.COOKIES = cookies
config.CRAWLER_MAX_COMMENTS_COUNT_SINGLENOTES = max_comments_count_singlenotes config.CRAWLER_MAX_COMMENTS_COUNT_SINGLENOTES = max_comments_count_singlenotes
config.CRAWLER_MAX_NOTES_COUNT = crawler_max_notes_count
config.MAX_CONCURRENCY_NUM = max_concurrency_num config.MAX_CONCURRENCY_NUM = max_concurrency_num
config.SAVE_DATA_PATH = save_data_path config.SAVE_DATA_PATH = save_data_path
config.ENABLE_IP_PROXY = enable_ip_proxy_value config.ENABLE_IP_PROXY = enable_ip_proxy_value
config.IP_PROXY_POOL_COUNT = ip_proxy_pool_count config.IP_PROXY_POOL_COUNT = ip_proxy_pool_count
config.IP_PROXY_PROVIDER_NAME = ip_proxy_provider_name config.IP_PROXY_PROVIDER_NAME = ip_proxy_provider_name
config.STATIC_PROXY_URL = static_proxy_url
# Set platform-specific ID lists for detail/creator mode # Set platform-specific ID lists for detail/creator mode
if specified_id_list: if specified_id_list:

View File

@@ -37,7 +37,11 @@ ENABLE_IP_PROXY = False
IP_PROXY_POOL_COUNT = 2 IP_PROXY_POOL_COUNT = 2
# Proxy IP provider name # Proxy IP provider name
IP_PROXY_PROVIDER_NAME = "kuaidaili" # kuaidaili | wandouhttp IP_PROXY_PROVIDER_NAME = "kuaidaili" # kuaidaili | wandouhttp | static
# Static proxy configuration (used when IP_PROXY_PROVIDER_NAME is set to "static")
# Format: "http://your_home_domain:port" or "http://user:password@your_home_domain:port"
STATIC_PROXY_URL = ""
# Setting to True will not open the browser (headless browser) # Setting to True will not open the browser (headless browser)
# Setting False will open a browser # Setting False will open a browser

View File

@@ -22,7 +22,9 @@
# @Time : 2023/12/2 13:45 # @Time : 2023/12/2 13:45
# @Desc : IP proxy pool implementation # @Desc : IP proxy pool implementation
import random import random
import time
from typing import Dict, List from typing import Dict, List
from urllib.parse import unquote, urlparse
import httpx import httpx
from tenacity import retry, stop_after_attempt, wait_fixed from tenacity import retry, stop_after_attempt, wait_fixed
@@ -150,9 +152,46 @@ class ProxyIpPool:
await self.load_proxies() await self.load_proxies()
class StaticProxyProvider(ProxyProvider):
    """Proxy provider backed by the single URL stored in ``config.STATIC_PROXY_URL``."""

    async def get_proxy(self, num: int) -> List[IpInfoModel]:
        """Turn ``config.STATIC_PROXY_URL`` into a one-element proxy list.

        Args:
            num: Requested number of proxies. Not used by this provider; a
                single entry is returned at most.

        Returns:
            A list holding one ``IpInfoModel`` built from the configured URL,
            or an empty list when the URL is unset, uses a scheme other than
            http/https, has no host, or fails to parse.
        """
        static_url = getattr(config, "STATIC_PROXY_URL", "")
        if not static_url:
            utils.logger.warning("[StaticProxyProvider] STATIC_PROXY_URL is not configured!")
            return []

        default_ports = {"http": 80, "https": 443}
        try:
            parts = urlparse(static_url)
            scheme = parts.scheme if parts.scheme else "http"
            if scheme not in default_ports:
                utils.logger.error(f"[StaticProxyProvider] Unsupported proxy scheme: {scheme}")
                return []

            host = parts.hostname or ""
            # Reading ``.port`` raises ValueError for a non-numeric port; that
            # is handled by the except clause below.
            port = parts.port or default_ports[scheme]
            if not host:
                utils.logger.error("[StaticProxyProvider] STATIC_PROXY_URL host is empty!")
                return []

            # Credentials may be percent-encoded inside the URL; decode them.
            username = unquote(parts.username or "")
            secret = unquote(parts.password or "")
            # Far-future expiry: a static proxy is treated as non-expiring.
            expires_at = int(time.time()) + 99999999

            static_proxy = IpInfoModel(
                ip=host,
                port=port,
                user=username,
                password=secret,
                protocol=f"{scheme}://",
                expired_time_ts=expires_at,
            )
            return [static_proxy]
        except Exception as e:
            utils.logger.error(f"[StaticProxyProvider] Parse static proxy url error: {e}")
            return []
IpProxyProvider: Dict[str, ProxyProvider] = { IpProxyProvider: Dict[str, ProxyProvider] = {
ProviderNameEnum.KUAI_DAILI_PROVIDER.value: new_kuai_daili_proxy(), ProviderNameEnum.KUAI_DAILI_PROVIDER.value: new_kuai_daili_proxy(),
ProviderNameEnum.WANDOU_HTTP_PROVIDER.value: new_wandou_http_proxy(), ProviderNameEnum.WANDOU_HTTP_PROVIDER.value: new_wandou_http_proxy(),
ProviderNameEnum.STATIC_PROVIDER.value: StaticProxyProvider(),
} }
@@ -163,9 +202,10 @@ async def create_ip_pool(ip_pool_count: int, enable_validate_ip: bool) -> ProxyI
:param enable_validate_ip: Whether to enable IP proxy validation :param enable_validate_ip: Whether to enable IP proxy validation
:return: :return:
""" """
is_static = config.IP_PROXY_PROVIDER_NAME == ProviderNameEnum.STATIC_PROVIDER.value
pool = ProxyIpPool( pool = ProxyIpPool(
ip_pool_count=ip_pool_count, ip_pool_count=ip_pool_count,
enable_validate_ip=enable_validate_ip, enable_validate_ip=False if is_static else enable_validate_ip,
ip_provider=IpProxyProvider.get(config.IP_PROXY_PROVIDER_NAME), ip_provider=IpProxyProvider.get(config.IP_PROXY_PROVIDER_NAME),
) )
await pool.load_proxies() await pool.load_proxies()

View File

@@ -32,6 +32,7 @@ from pydantic import BaseModel, Field
class ProviderNameEnum(Enum): class ProviderNameEnum(Enum):
KUAI_DAILI_PROVIDER: str = "kuaidaili" KUAI_DAILI_PROVIDER: str = "kuaidaili"
WANDOU_HTTP_PROVIDER: str = "wandouhttp" WANDOU_HTTP_PROVIDER: str = "wandouhttp"
STATIC_PROVIDER: str = "static"
class IpInfoModel(BaseModel): class IpInfoModel(BaseModel):

137
tests/test_api_limits.py Normal file
View File

@@ -0,0 +1,137 @@
# -*- coding: utf-8 -*-
import pytest
import config
from unittest.mock import AsyncMock, patch
from fastapi.testclient import TestClient
from cmd_arg import parse_cmd
from api.schemas import CrawlerStartRequest, PlatformEnum, LoginTypeEnum, CrawlerTypeEnum
from api.services.crawler_manager import CrawlerManager
from api.main import app
@pytest.mark.asyncio
async def test_cmd_arg_crawler_max_notes_count():
    """``parse_cmd`` must copy both CLI limit flags into the ``config`` module."""
    # Snapshot the two settings this test overwrites so they can be restored.
    saved_limits = (
        config.CRAWLER_MAX_NOTES_COUNT,
        config.CRAWLER_MAX_COMMENTS_COUNT_SINGLENOTES,
    )
    argv = [
        "--platform", "xhs",
        "--crawler_max_notes_count", "42",
        "--max_comments_count_singlenotes", "24",
    ]

    try:
        await parse_cmd(argv)

        assert config.CRAWLER_MAX_NOTES_COUNT == 42
        assert config.CRAWLER_MAX_COMMENTS_COUNT_SINGLENOTES == 24
    finally:
        # Put the original values back whether or not the assertions passed.
        (
            config.CRAWLER_MAX_NOTES_COUNT,
            config.CRAWLER_MAX_COMMENTS_COUNT_SINGLENOTES,
        ) = saved_limits
def test_crawler_manager_build_command():
    """``_build_command`` adds the limit flags only when the request sets them."""
    manager = CrawlerManager()
    common_fields = dict(
        platform=PlatformEnum.XHS,
        login_type=LoginTypeEnum.QRCODE,
        crawler_type=CrawlerTypeEnum.SEARCH,
        keywords="test",
    )

    # Case 1: both limits left as None -> neither flag may appear.
    unlimited_request = CrawlerStartRequest(
        **common_fields,
        max_notes_count=None,
        max_comments_count=None,
    )
    unlimited_cmd = manager._build_command(unlimited_request)
    assert "--crawler_max_notes_count" not in unlimited_cmd
    assert "--max_comments_count_singlenotes" not in unlimited_cmd

    # Case 2: both limits supplied -> each flag is followed by its value.
    limited_request = CrawlerStartRequest(
        **common_fields,
        max_notes_count=50,
        max_comments_count=5,
    )
    limited_cmd = manager._build_command(limited_request)
    expected_flags = (
        ("--crawler_max_notes_count", "50"),
        ("--max_comments_count_singlenotes", "5"),
    )
    for flag, expected_value in expected_flags:
        assert flag in limited_cmd
        flag_position = limited_cmd.index(flag)
        assert limited_cmd[flag_position + 1] == expected_value
def test_api_start_crawler_with_limits():
    """POST /api/crawler/start forwards explicit limits to ``crawler_manager.start``."""
    payload = {
        "platform": "xhs",
        "login_type": "qrcode",
        "crawler_type": "search",
        "keywords": "test",
        "max_notes_count": 50,
        "max_comments_count": 5,
    }
    http = TestClient(app)

    # Replace the real start coroutine so no crawler process is launched.
    with patch("api.routers.crawler.crawler_manager.start", new_callable=AsyncMock) as mock_start:
        mock_start.return_value = True

        reply = http.post("/api/crawler/start", json=payload)

        assert reply.status_code == 200
        assert reply.json() == {"status": "ok", "message": "Crawler started successfully"}

        mock_start.assert_called_once()
        # First positional argument of the single recorded call is the parsed request.
        forwarded = mock_start.call_args.args[0]
        assert forwarded.platform == PlatformEnum.XHS
        assert forwarded.max_notes_count == 50
        assert forwarded.max_comments_count == 5
def test_api_start_crawler_without_limits():
    """Omitting both limit fields leaves them as ``None`` on the forwarded request."""
    payload = {
        "platform": "xhs",
        "login_type": "qrcode",
        "crawler_type": "search",
        "keywords": "test",
    }
    http = TestClient(app)

    # Replace the real start coroutine so no crawler process is launched.
    with patch("api.routers.crawler.crawler_manager.start", new_callable=AsyncMock) as mock_start:
        mock_start.return_value = True

        reply = http.post("/api/crawler/start", json=payload)

        assert reply.status_code == 200

        mock_start.assert_called_once()
        # First positional argument of the single recorded call is the parsed request.
        forwarded = mock_start.call_args.args[0]
        assert forwarded.platform == PlatformEnum.XHS
        assert forwarded.max_notes_count is None
        assert forwarded.max_comments_count is None
@pytest.mark.parametrize(
    ("field_name", "value"),
    [
        (limit_field, bad_value)
        for limit_field in ("max_notes_count", "max_comments_count")
        for bad_value in (0, -1, 10001)
    ],
)
def test_api_rejects_invalid_limits(field_name, value):
    """Out-of-range limits are answered with 422 and never reach ``crawler_manager.start``."""
    valid_fields = {
        "platform": "xhs",
        "login_type": "qrcode",
        "crawler_type": "search",
        "keywords": "test",
    }
    # Add the single offending field on top of an otherwise valid request.
    payload = {**valid_fields, field_name: value}
    http = TestClient(app)

    with patch("api.routers.crawler.crawler_manager.start", new_callable=AsyncMock) as mock_start:
        reply = http.post("/api/crawler/start", json=payload)

        assert reply.status_code == 422
        mock_start.assert_not_called()

View File

@@ -0,0 +1,49 @@
# -*- coding: utf-8 -*-
import pytest
import config
from proxy.proxy_ip_pool import StaticProxyProvider, create_ip_pool
from proxy.types import ProviderNameEnum
def test_default_proxy_provider_remains_existing_provider():
    """The shipped proxy defaults are kuaidaili, a pool of two, and no static URL."""
    expected_provider = ProviderNameEnum.KUAI_DAILI_PROVIDER.value

    assert config.IP_PROXY_PROVIDER_NAME == expected_provider
    assert config.IP_PROXY_POOL_COUNT == 2
    assert config.STATIC_PROXY_URL == ""
@pytest.mark.asyncio
async def test_static_proxy_provider_parses_proxy_url(monkeypatch):
    """A full proxy URL is split into host, port, decoded credentials and protocol."""
    # The password carries a percent-encoded "@" to exercise credential decoding.
    monkeypatch.setattr(config, "STATIC_PROXY_URL", "http://user:p%40ss@example.com:8080")

    provider = StaticProxyProvider()
    result = await provider.get_proxy(1)

    assert len(result) == 1
    entry = result[0]
    expected_fields = {
        "ip": "example.com",
        "port": 8080,
        "user": "user",
        "password": "p@ss",
        "protocol": "http://",
    }
    for attribute, expected in expected_fields.items():
        assert getattr(entry, attribute) == expected
    assert entry.expired_time_ts is not None
@pytest.mark.asyncio
async def test_static_proxy_provider_rejects_invalid_url(monkeypatch):
    """A URL whose port is not numeric yields an empty proxy list."""
    placeholder_url = "http://your_home_domain:port"
    monkeypatch.setattr(config, "STATIC_PROXY_URL", placeholder_url)

    provider = StaticProxyProvider()
    result = await provider.get_proxy(1)

    assert result == []
@pytest.mark.asyncio
async def test_static_proxy_pool_disables_validation(monkeypatch):
    """With the static provider selected, the pool turns validation off even if requested."""
    static_name = ProviderNameEnum.STATIC_PROVIDER.value
    monkeypatch.setattr(config, "IP_PROXY_PROVIDER_NAME", static_name)
    monkeypatch.setattr(config, "STATIC_PROXY_URL", "https://example.com:8443")

    # Validation is requested here on purpose; the pool is expected to override it.
    ip_pool = await create_ip_pool(ip_pool_count=2, enable_validate_ip=True)

    assert ip_pool.enable_validate_ip is False
    loaded = ip_pool.proxy_list
    assert len(loaded) == 1
    assert loaded[0].protocol == "https://"