mirror of
https://github.com/NanmiCoder/MediaCrawler.git
synced 2026-06-03 08:27:26 +08:00
Merge pull request #900 from zanmeipaul/main
feat: 启动任务接口添加帖子/视频数量与评论数量覆盖支持
This commit is contained in:
@@ -18,7 +18,10 @@
|
|||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Optional, Literal
|
from typing import Optional, Literal
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
MAX_API_LIMIT_COUNT = 10000
|
||||||
|
|
||||||
|
|
||||||
class PlatformEnum(str, Enum):
|
class PlatformEnum(str, Enum):
|
||||||
@@ -71,6 +74,8 @@ class CrawlerStartRequest(BaseModel):
|
|||||||
save_option: SaveDataOptionEnum = SaveDataOptionEnum.JSONL
|
save_option: SaveDataOptionEnum = SaveDataOptionEnum.JSONL
|
||||||
cookies: str = ""
|
cookies: str = ""
|
||||||
headless: bool = False
|
headless: bool = False
|
||||||
|
max_notes_count: Optional[int] = Field(default=None, ge=1, le=MAX_API_LIMIT_COUNT)
|
||||||
|
max_comments_count: Optional[int] = Field(default=None, ge=1, le=MAX_API_LIMIT_COUNT)
|
||||||
|
|
||||||
|
|
||||||
class CrawlerStatusResponse(BaseModel):
|
class CrawlerStatusResponse(BaseModel):
|
||||||
|
|||||||
@@ -225,6 +225,12 @@ class CrawlerManager:
|
|||||||
cmd.extend(["--get_comment", "true" if config.enable_comments else "false"])
|
cmd.extend(["--get_comment", "true" if config.enable_comments else "false"])
|
||||||
cmd.extend(["--get_sub_comment", "true" if config.enable_sub_comments else "false"])
|
cmd.extend(["--get_sub_comment", "true" if config.enable_sub_comments else "false"])
|
||||||
|
|
||||||
|
if config.max_notes_count is not None:
|
||||||
|
cmd.extend(["--crawler_max_notes_count", str(config.max_notes_count)])
|
||||||
|
|
||||||
|
if config.max_comments_count is not None:
|
||||||
|
cmd.extend(["--max_comments_count_singlenotes", str(config.max_comments_count)])
|
||||||
|
|
||||||
if config.cookies:
|
if config.cookies:
|
||||||
cmd.extend(["--cookies", config.cookies])
|
cmd.extend(["--cookies", config.cookies])
|
||||||
|
|
||||||
|
|||||||
@@ -275,6 +275,14 @@ async def parse_cmd(argv: Optional[Sequence[str]] = None):
|
|||||||
rich_help_panel="Comment Configuration",
|
rich_help_panel="Comment Configuration",
|
||||||
),
|
),
|
||||||
] = config.CRAWLER_MAX_COMMENTS_COUNT_SINGLENOTES,
|
] = config.CRAWLER_MAX_COMMENTS_COUNT_SINGLENOTES,
|
||||||
|
crawler_max_notes_count: Annotated[
|
||||||
|
int,
|
||||||
|
typer.Option(
|
||||||
|
"--crawler_max_notes_count",
|
||||||
|
help="Maximum number of videos/posts to crawl",
|
||||||
|
rich_help_panel="Basic Configuration",
|
||||||
|
),
|
||||||
|
] = config.CRAWLER_MAX_NOTES_COUNT,
|
||||||
max_concurrency_num: Annotated[
|
max_concurrency_num: Annotated[
|
||||||
int,
|
int,
|
||||||
typer.Option(
|
typer.Option(
|
||||||
@@ -312,10 +320,18 @@ async def parse_cmd(argv: Optional[Sequence[str]] = None):
|
|||||||
str,
|
str,
|
||||||
typer.Option(
|
typer.Option(
|
||||||
"--ip_proxy_provider_name",
|
"--ip_proxy_provider_name",
|
||||||
help="IP proxy provider name (kuaidaili | wandouhttp)",
|
help="IP proxy provider name (kuaidaili | wandouhttp | static)",
|
||||||
rich_help_panel="Proxy Configuration",
|
rich_help_panel="Proxy Configuration",
|
||||||
),
|
),
|
||||||
] = config.IP_PROXY_PROVIDER_NAME,
|
] = config.IP_PROXY_PROVIDER_NAME,
|
||||||
|
static_proxy_url: Annotated[
|
||||||
|
str,
|
||||||
|
typer.Option(
|
||||||
|
"--static_proxy_url",
|
||||||
|
help="Static proxy URL, for example http://user:password@host:port",
|
||||||
|
rich_help_panel="Proxy Configuration",
|
||||||
|
),
|
||||||
|
] = config.STATIC_PROXY_URL,
|
||||||
) -> SimpleNamespace:
|
) -> SimpleNamespace:
|
||||||
"""MediaCrawler 命令行入口"""
|
"""MediaCrawler 命令行入口"""
|
||||||
|
|
||||||
@@ -342,11 +358,13 @@ async def parse_cmd(argv: Optional[Sequence[str]] = None):
|
|||||||
config.SAVE_DATA_OPTION = save_data_option.value
|
config.SAVE_DATA_OPTION = save_data_option.value
|
||||||
config.COOKIES = cookies
|
config.COOKIES = cookies
|
||||||
config.CRAWLER_MAX_COMMENTS_COUNT_SINGLENOTES = max_comments_count_singlenotes
|
config.CRAWLER_MAX_COMMENTS_COUNT_SINGLENOTES = max_comments_count_singlenotes
|
||||||
|
config.CRAWLER_MAX_NOTES_COUNT = crawler_max_notes_count
|
||||||
config.MAX_CONCURRENCY_NUM = max_concurrency_num
|
config.MAX_CONCURRENCY_NUM = max_concurrency_num
|
||||||
config.SAVE_DATA_PATH = save_data_path
|
config.SAVE_DATA_PATH = save_data_path
|
||||||
config.ENABLE_IP_PROXY = enable_ip_proxy_value
|
config.ENABLE_IP_PROXY = enable_ip_proxy_value
|
||||||
config.IP_PROXY_POOL_COUNT = ip_proxy_pool_count
|
config.IP_PROXY_POOL_COUNT = ip_proxy_pool_count
|
||||||
config.IP_PROXY_PROVIDER_NAME = ip_proxy_provider_name
|
config.IP_PROXY_PROVIDER_NAME = ip_proxy_provider_name
|
||||||
|
config.STATIC_PROXY_URL = static_proxy_url
|
||||||
|
|
||||||
# Set platform-specific ID lists for detail/creator mode
|
# Set platform-specific ID lists for detail/creator mode
|
||||||
if specified_id_list:
|
if specified_id_list:
|
||||||
|
|||||||
@@ -37,7 +37,11 @@ ENABLE_IP_PROXY = False
|
|||||||
IP_PROXY_POOL_COUNT = 2
|
IP_PROXY_POOL_COUNT = 2
|
||||||
|
|
||||||
# Proxy IP provider name
|
# Proxy IP provider name
|
||||||
IP_PROXY_PROVIDER_NAME = "kuaidaili" # kuaidaili | wandouhttp
|
IP_PROXY_PROVIDER_NAME = "kuaidaili" # kuaidaili | wandouhttp | static
|
||||||
|
|
||||||
|
# Static proxy configuration (used when IP_PROXY_PROVIDER_NAME is set to "static")
|
||||||
|
# Format: "http://your_home_domain:port" or "http://user:password@your_home_domain:port"
|
||||||
|
STATIC_PROXY_URL = ""
|
||||||
|
|
||||||
# Setting to True will not open the browser (headless browser)
|
# Setting to True will not open the browser (headless browser)
|
||||||
# Setting False will open a browser
|
# Setting False will open a browser
|
||||||
|
|||||||
@@ -22,7 +22,9 @@
|
|||||||
# @Time : 2023/12/2 13:45
|
# @Time : 2023/12/2 13:45
|
||||||
# @Desc : IP proxy pool implementation
|
# @Desc : IP proxy pool implementation
|
||||||
import random
|
import random
|
||||||
|
import time
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
from urllib.parse import unquote, urlparse
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from tenacity import retry, stop_after_attempt, wait_fixed
|
from tenacity import retry, stop_after_attempt, wait_fixed
|
||||||
@@ -150,9 +152,46 @@ class ProxyIpPool:
|
|||||||
await self.load_proxies()
|
await self.load_proxies()
|
||||||
|
|
||||||
|
|
||||||
|
class StaticProxyProvider(ProxyProvider):
|
||||||
|
async def get_proxy(self, num: int) -> List[IpInfoModel]:
|
||||||
|
proxy_url = getattr(config, "STATIC_PROXY_URL", "")
|
||||||
|
if not proxy_url:
|
||||||
|
utils.logger.warning("[StaticProxyProvider] STATIC_PROXY_URL is not configured!")
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
parsed = urlparse(proxy_url)
|
||||||
|
scheme = parsed.scheme or "http"
|
||||||
|
if scheme not in {"http", "https"}:
|
||||||
|
utils.logger.error(f"[StaticProxyProvider] Unsupported proxy scheme: {scheme}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
ip = parsed.hostname or ""
|
||||||
|
port = parsed.port or (443 if scheme == "https" else 80)
|
||||||
|
if not ip:
|
||||||
|
utils.logger.error("[StaticProxyProvider] STATIC_PROXY_URL host is empty!")
|
||||||
|
return []
|
||||||
|
|
||||||
|
return [
|
||||||
|
IpInfoModel(
|
||||||
|
ip=ip,
|
||||||
|
port=port,
|
||||||
|
user=unquote(parsed.username or ""),
|
||||||
|
password=unquote(parsed.password or ""),
|
||||||
|
protocol=f"{scheme}://",
|
||||||
|
# Static proxy doesn't expire.
|
||||||
|
expired_time_ts=int(time.time()) + 99999999,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
except Exception as e:
|
||||||
|
utils.logger.error(f"[StaticProxyProvider] Parse static proxy url error: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
IpProxyProvider: Dict[str, ProxyProvider] = {
|
IpProxyProvider: Dict[str, ProxyProvider] = {
|
||||||
ProviderNameEnum.KUAI_DAILI_PROVIDER.value: new_kuai_daili_proxy(),
|
ProviderNameEnum.KUAI_DAILI_PROVIDER.value: new_kuai_daili_proxy(),
|
||||||
ProviderNameEnum.WANDOU_HTTP_PROVIDER.value: new_wandou_http_proxy(),
|
ProviderNameEnum.WANDOU_HTTP_PROVIDER.value: new_wandou_http_proxy(),
|
||||||
|
ProviderNameEnum.STATIC_PROVIDER.value: StaticProxyProvider(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -163,9 +202,10 @@ async def create_ip_pool(ip_pool_count: int, enable_validate_ip: bool) -> ProxyI
|
|||||||
:param enable_validate_ip: Whether to enable IP proxy validation
|
:param enable_validate_ip: Whether to enable IP proxy validation
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
is_static = config.IP_PROXY_PROVIDER_NAME == ProviderNameEnum.STATIC_PROVIDER.value
|
||||||
pool = ProxyIpPool(
|
pool = ProxyIpPool(
|
||||||
ip_pool_count=ip_pool_count,
|
ip_pool_count=ip_pool_count,
|
||||||
enable_validate_ip=enable_validate_ip,
|
enable_validate_ip=False if is_static else enable_validate_ip,
|
||||||
ip_provider=IpProxyProvider.get(config.IP_PROXY_PROVIDER_NAME),
|
ip_provider=IpProxyProvider.get(config.IP_PROXY_PROVIDER_NAME),
|
||||||
)
|
)
|
||||||
await pool.load_proxies()
|
await pool.load_proxies()
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ from pydantic import BaseModel, Field
|
|||||||
class ProviderNameEnum(Enum):
|
class ProviderNameEnum(Enum):
|
||||||
KUAI_DAILI_PROVIDER: str = "kuaidaili"
|
KUAI_DAILI_PROVIDER: str = "kuaidaili"
|
||||||
WANDOU_HTTP_PROVIDER: str = "wandouhttp"
|
WANDOU_HTTP_PROVIDER: str = "wandouhttp"
|
||||||
|
STATIC_PROVIDER: str = "static"
|
||||||
|
|
||||||
|
|
||||||
class IpInfoModel(BaseModel):
|
class IpInfoModel(BaseModel):
|
||||||
|
|||||||
137
tests/test_api_limits.py
Normal file
137
tests/test_api_limits.py
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
import pytest
|
||||||
|
import config
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from cmd_arg import parse_cmd
|
||||||
|
from api.schemas import CrawlerStartRequest, PlatformEnum, LoginTypeEnum, CrawlerTypeEnum
|
||||||
|
from api.services.crawler_manager import CrawlerManager
|
||||||
|
from api.main import app
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cmd_arg_crawler_max_notes_count():
|
||||||
|
# Store original values
|
||||||
|
orig_notes = config.CRAWLER_MAX_NOTES_COUNT
|
||||||
|
orig_comments = config.CRAWLER_MAX_COMMENTS_COUNT_SINGLENOTES
|
||||||
|
|
||||||
|
try:
|
||||||
|
await parse_cmd([
|
||||||
|
"--platform", "xhs",
|
||||||
|
"--crawler_max_notes_count", "42",
|
||||||
|
"--max_comments_count_singlenotes", "24"
|
||||||
|
])
|
||||||
|
assert config.CRAWLER_MAX_NOTES_COUNT == 42
|
||||||
|
assert config.CRAWLER_MAX_COMMENTS_COUNT_SINGLENOTES == 24
|
||||||
|
finally:
|
||||||
|
config.CRAWLER_MAX_NOTES_COUNT = orig_notes
|
||||||
|
config.CRAWLER_MAX_COMMENTS_COUNT_SINGLENOTES = orig_comments
|
||||||
|
|
||||||
|
def test_crawler_manager_build_command():
|
||||||
|
cm = CrawlerManager()
|
||||||
|
|
||||||
|
# 1. No max limits passed in API request
|
||||||
|
req1 = CrawlerStartRequest(
|
||||||
|
platform=PlatformEnum.XHS,
|
||||||
|
login_type=LoginTypeEnum.QRCODE,
|
||||||
|
crawler_type=CrawlerTypeEnum.SEARCH,
|
||||||
|
keywords="test",
|
||||||
|
max_notes_count=None,
|
||||||
|
max_comments_count=None
|
||||||
|
)
|
||||||
|
cmd1 = cm._build_command(req1)
|
||||||
|
# Check that the custom arguments are NOT present
|
||||||
|
assert "--crawler_max_notes_count" not in cmd1
|
||||||
|
assert "--max_comments_count_singlenotes" not in cmd1
|
||||||
|
|
||||||
|
# 2. Both limits passed in API request
|
||||||
|
req2 = CrawlerStartRequest(
|
||||||
|
platform=PlatformEnum.XHS,
|
||||||
|
login_type=LoginTypeEnum.QRCODE,
|
||||||
|
crawler_type=CrawlerTypeEnum.SEARCH,
|
||||||
|
keywords="test",
|
||||||
|
max_notes_count=50,
|
||||||
|
max_comments_count=5
|
||||||
|
)
|
||||||
|
cmd2 = cm._build_command(req2)
|
||||||
|
# Check that they are correctly added
|
||||||
|
assert "--crawler_max_notes_count" in cmd2
|
||||||
|
idx_notes = cmd2.index("--crawler_max_notes_count")
|
||||||
|
assert cmd2[idx_notes + 1] == "50"
|
||||||
|
|
||||||
|
assert "--max_comments_count_singlenotes" in cmd2
|
||||||
|
idx_comments = cmd2.index("--max_comments_count_singlenotes")
|
||||||
|
assert cmd2[idx_comments + 1] == "5"
|
||||||
|
|
||||||
|
def test_api_start_crawler_with_limits():
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
with patch("api.routers.crawler.crawler_manager.start", new_callable=AsyncMock) as mock_start:
|
||||||
|
mock_start.return_value = True
|
||||||
|
|
||||||
|
# Test case 1: with limits
|
||||||
|
response = client.post("/api/crawler/start", json={
|
||||||
|
"platform": "xhs",
|
||||||
|
"login_type": "qrcode",
|
||||||
|
"crawler_type": "search",
|
||||||
|
"keywords": "test",
|
||||||
|
"max_notes_count": 50,
|
||||||
|
"max_comments_count": 5
|
||||||
|
})
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"status": "ok", "message": "Crawler started successfully"}
|
||||||
|
|
||||||
|
mock_start.assert_called_once()
|
||||||
|
called_request = mock_start.call_args[0][0]
|
||||||
|
assert called_request.platform == PlatformEnum.XHS
|
||||||
|
assert called_request.max_notes_count == 50
|
||||||
|
assert called_request.max_comments_count == 5
|
||||||
|
|
||||||
|
def test_api_start_crawler_without_limits():
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
with patch("api.routers.crawler.crawler_manager.start", new_callable=AsyncMock) as mock_start:
|
||||||
|
mock_start.return_value = True
|
||||||
|
|
||||||
|
# Test case 2: without limits
|
||||||
|
response = client.post("/api/crawler/start", json={
|
||||||
|
"platform": "xhs",
|
||||||
|
"login_type": "qrcode",
|
||||||
|
"crawler_type": "search",
|
||||||
|
"keywords": "test"
|
||||||
|
})
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
mock_start.assert_called_once()
|
||||||
|
called_request = mock_start.call_args[0][0]
|
||||||
|
assert called_request.platform == PlatformEnum.XHS
|
||||||
|
assert called_request.max_notes_count is None
|
||||||
|
assert called_request.max_comments_count is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("field_name", "value"),
|
||||||
|
[
|
||||||
|
("max_notes_count", 0),
|
||||||
|
("max_notes_count", -1),
|
||||||
|
("max_notes_count", 10001),
|
||||||
|
("max_comments_count", 0),
|
||||||
|
("max_comments_count", -1),
|
||||||
|
("max_comments_count", 10001),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_api_rejects_invalid_limits(field_name, value):
|
||||||
|
client = TestClient(app)
|
||||||
|
payload = {
|
||||||
|
"platform": "xhs",
|
||||||
|
"login_type": "qrcode",
|
||||||
|
"crawler_type": "search",
|
||||||
|
"keywords": "test",
|
||||||
|
field_name: value,
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch("api.routers.crawler.crawler_manager.start", new_callable=AsyncMock) as mock_start:
|
||||||
|
response = client.post("/api/crawler/start", json=payload)
|
||||||
|
|
||||||
|
assert response.status_code == 422
|
||||||
|
mock_start.assert_not_called()
|
||||||
49
tests/test_static_proxy_provider.py
Normal file
49
tests/test_static_proxy_provider.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import config
|
||||||
|
from proxy.proxy_ip_pool import StaticProxyProvider, create_ip_pool
|
||||||
|
from proxy.types import ProviderNameEnum
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_proxy_provider_remains_existing_provider():
|
||||||
|
assert config.IP_PROXY_PROVIDER_NAME == ProviderNameEnum.KUAI_DAILI_PROVIDER.value
|
||||||
|
assert config.IP_PROXY_POOL_COUNT == 2
|
||||||
|
assert config.STATIC_PROXY_URL == ""
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_static_proxy_provider_parses_proxy_url(monkeypatch):
|
||||||
|
monkeypatch.setattr(config, "STATIC_PROXY_URL", "http://user:p%40ss@example.com:8080")
|
||||||
|
|
||||||
|
proxies = await StaticProxyProvider().get_proxy(1)
|
||||||
|
|
||||||
|
assert len(proxies) == 1
|
||||||
|
proxy = proxies[0]
|
||||||
|
assert proxy.ip == "example.com"
|
||||||
|
assert proxy.port == 8080
|
||||||
|
assert proxy.user == "user"
|
||||||
|
assert proxy.password == "p@ss"
|
||||||
|
assert proxy.protocol == "http://"
|
||||||
|
assert proxy.expired_time_ts is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_static_proxy_provider_rejects_invalid_url(monkeypatch):
|
||||||
|
monkeypatch.setattr(config, "STATIC_PROXY_URL", "http://your_home_domain:port")
|
||||||
|
|
||||||
|
proxies = await StaticProxyProvider().get_proxy(1)
|
||||||
|
|
||||||
|
assert proxies == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_static_proxy_pool_disables_validation(monkeypatch):
|
||||||
|
monkeypatch.setattr(config, "IP_PROXY_PROVIDER_NAME", ProviderNameEnum.STATIC_PROVIDER.value)
|
||||||
|
monkeypatch.setattr(config, "STATIC_PROXY_URL", "https://example.com:8443")
|
||||||
|
|
||||||
|
pool = await create_ip_pool(ip_pool_count=2, enable_validate_ip=True)
|
||||||
|
|
||||||
|
assert pool.enable_validate_ip is False
|
||||||
|
assert len(pool.proxy_list) == 1
|
||||||
|
assert pool.proxy_list[0].protocol == "https://"
|
||||||
Reference in New Issue
Block a user