diff --git a/src/crawlee/_utils/sitemap.py b/src/crawlee/_utils/sitemap.py index 95d1e26a5f..b90d2e6935 100644 --- a/src/crawlee/_utils/sitemap.py +++ b/src/crawlee/_utils/sitemap.py @@ -1,8 +1,10 @@ from __future__ import annotations import asyncio +import re import zlib from codecs import getincrementaldecoder +from collections import defaultdict from contextlib import suppress from dataclasses import dataclass from datetime import datetime, timedelta @@ -16,6 +18,9 @@ from typing_extensions import NotRequired, override from yarl import URL +from crawlee._utils.web import is_status_code_successful +from crawlee.errors import ProxyError + if TYPE_CHECKING: from collections.abc import AsyncGenerator from xml.sax.xmlreader import AttributesImpl @@ -27,6 +32,8 @@ VALID_CHANGE_FREQS = {'always', 'hourly', 'daily', 'weekly', 'monthly', 'yearly', 'never'} SITEMAP_HEADERS = {'accept': 'text/plain, application/xhtml+xml, application/xml;q=0.9, */*;q=0.8'} +SITEMAP_URL_PATTERN = re.compile(r'\/sitemap\.(?:xml|txt)(?:\.gz)?$', re.IGNORECASE) +COMMON_SITEMAP_PATHS = ['/sitemap.xml', '/sitemap.txt', '/sitemap_index.xml'] @dataclass() @@ -384,7 +391,7 @@ def urls(self) -> list[str]: @classmethod async def try_common_names(cls, url: str, http_client: HttpClient, proxy_info: ProxyInfo | None = None) -> Sitemap: base_url = URL(url) - sitemap_urls = [str(base_url.with_path('/sitemap.xml')), str(base_url.with_path('/sitemap.txt'))] + sitemap_urls = [str(base_url.with_path(path)) for path in COMMON_SITEMAP_PATHS] return await cls.load(sitemap_urls, http_client, proxy_info) @classmethod @@ -484,3 +491,140 @@ async def parse_sitemap( yield result else: logger.warning(f'Invalid source configuration: {source}') + + +async def _merge_async_generators(*generators: AsyncGenerator) -> AsyncGenerator: + queue: asyncio.Queue = asyncio.Queue() + + end_feed = object() + + async def feed(gen: AsyncGenerator) -> None: + try: + async for item in gen: + await queue.put(item) + except Exception: + logger.warning(f'Error in generator: {gen}', exc_info=True) + finally: + await queue.put(end_feed) + + tasks = [asyncio.create_task(feed(gen)) for gen in generators] + remaining_tasks = len(tasks) + + try: + while remaining_tasks > 0: + item = await queue.get() + if item is end_feed: + remaining_tasks -= 1 + else: + yield item + finally: + for task in tasks: + task.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + + +async def _discover_for_hostname( + hostname: str, + hostname_urls: list[str], + *, + http_client: HttpClient, + proxy_info: ProxyInfo | None = None, + request_timeout: timedelta, + method_for_checking: Literal['HEAD', 'GET'] = 'HEAD', +) -> AsyncGenerator[str, None]: + # Import here to avoid circular imports. + from crawlee._utils.robots import RobotsTxtFile # noqa: PLC0415 + + domain_seen: set[str] = set() + hostname_urls = list(set(hostname_urls)) # Remove duplicates + + def _check_and_add(url: str) -> bool: + if url in domain_seen: + return False + domain_seen.add(url) + return True + + # Try getting sitemaps from robots.txt first + robots = await RobotsTxtFile.find(url=hostname_urls[0], http_client=http_client, proxy_info=proxy_info) + for sitemap_url in robots.get_sitemaps(): + if _check_and_add(sitemap_url): + yield sitemap_url + + # Check maybe provided URLs have sitemap url + matching_sitemap_urls = [url for url in hostname_urls if SITEMAP_URL_PATTERN.search(url)] + + if matching_sitemap_urls: + for sitemap_url in matching_sitemap_urls: + if _check_and_add(sitemap_url): + yield sitemap_url + else: + # Check common sitemap locations + base_url = URL(hostname_urls[0]) + for path in COMMON_SITEMAP_PATHS: + candidate = str(base_url.with_path(path)) + if candidate in domain_seen: + continue + try: + response = await http_client.send_request( + candidate, method=method_for_checking, proxy_info=proxy_info, timeout=request_timeout + ) + if is_status_code_successful(response.status_code) and _check_and_add(candidate): + yield candidate + except ProxyError: + logger.warning(f'Proxy error when checking {candidate} with sitemap discovery for {hostname}') + except asyncio.TimeoutError: + logger.warning(f'Timeout when checking {candidate} with sitemap discovery for {hostname}') + except Exception: + logger.warning(f'Error when checking {candidate} with sitemap discovery for {hostname}', exc_info=True) + + +async def discover_valid_sitemaps( + urls: list[str], + *, + http_client: HttpClient, + proxy_info: ProxyInfo | None = None, + request_timeout: timedelta = timedelta(seconds=20), + method_for_checking: Literal['HEAD', 'GET'] = 'HEAD', +) -> AsyncGenerator[str, None]: + """Discover related sitemaps for the given URLs. + + Args: + urls: List of URLs to discover sitemaps for. + http_client: `HttpClient` to use for making requests. + proxy_info: Proxy configuration to use for requests. + request_timeout: Timeout for each request when checking for sitemaps. + method_for_checking: HTTP method to use when checking for sitemap existence (HEAD or GET). + """ + # Use a set to track seen sitemap URLs and avoid duplicates + seen = set() + + grouped_urls = defaultdict(list) + for url in urls: + try: + hostname = URL(url).host + except ValueError: + logger.warning(f'Invalid URL {url} skipped') + continue + + if not hostname: + logger.warning(f'URL {url} without host skipped') + continue + + grouped_urls[hostname].append(url) + + generators = [ + _discover_for_hostname( + hostname, + hostname_urls, + http_client=http_client, + proxy_info=proxy_info, + request_timeout=request_timeout, + method_for_checking=method_for_checking, + ) + for hostname, hostname_urls in grouped_urls.items() + ] + + async for sitemap_url in _merge_async_generators(*generators): + if sitemap_url not in seen: + seen.add(sitemap_url) + yield sitemap_url diff --git a/src/crawlee/_utils/web.py b/src/crawlee/_utils/web.py index 2624383abf..ff00480a67 100644 --- a/src/crawlee/_utils/web.py +++ b/src/crawlee/_utils/web.py @@ -1,11 +1,18 @@ from __future__ import annotations +from http import HTTPStatus + def is_status_code_client_error(value: int) -> bool: """Return `True` for 4xx status codes, `False` otherwise.""" - return 400 <= value <= 499 # noqa: PLR2004 + return HTTPStatus.BAD_REQUEST <= value < HTTPStatus.INTERNAL_SERVER_ERROR def is_status_code_server_error(value: int) -> bool: """Return `True` for 5xx status codes, `False` otherwise.""" - return value >= 500 # noqa: PLR2004 + return value >= HTTPStatus.INTERNAL_SERVER_ERROR + + +def is_status_code_successful(value: int) -> bool: + """Return `True` for 2xx and 3xx status codes, `False` otherwise.""" + return HTTPStatus.OK <= value < HTTPStatus.BAD_REQUEST diff --git a/tests/unit/_utils/test_sitemap.py b/tests/unit/_utils/test_sitemap.py index 807090eaa4..5f2005ca16 100644 --- a/tests/unit/_utils/test_sitemap.py +++ b/tests/unit/_utils/test_sitemap.py @@ -1,11 +1,13 @@ import base64 import gzip from datetime import datetime +from typing import Any +from unittest.mock import AsyncMock, MagicMock from yarl import URL -from crawlee._utils.sitemap import Sitemap, SitemapUrl, parse_sitemap -from crawlee.http_clients._base import HttpClient +from crawlee._utils.sitemap import Sitemap, SitemapUrl, discover_valid_sitemaps, parse_sitemap +from crawlee.http_clients._base import HttpClient, HttpResponse BASIC_SITEMAP = """ @@ -46,6 +48,23 @@ } +def _make_mock_client(url_map: dict[str, tuple[int, bytes]]) -> AsyncMock: + async def send_request(url: str, **_kwargs: Any) -> HttpResponse: + status, body = 404, b'' + for pattern, (s, b) in url_map.items(): + if pattern in url: + status, body = s, b + break + response = MagicMock(spec=HttpResponse) + response.status_code = status + response.read = AsyncMock(return_value=body) + return response + + client = AsyncMock(spec=HttpClient) + client.send_request.side_effect = send_request + return client + + def compress_gzip(data: str) -> bytes: """Compress a string using gzip.""" return gzip.compress(data.encode()) @@ -246,3 +265,84 @@ async def test_sitemap_from_string() -> None: assert len(sitemap.urls) == 5 assert set(sitemap.urls) == BASIC_RESULTS + + +async def test_discover_sitemap_from_robots_txt() -> None: + """Sitemap URL found in robots.txt is yielded.""" + robots_content = b'User-agent: *\nSitemap: http://example.com/custom-sitemap.xml' + http_client = _make_mock_client({'robots.txt': (200, robots_content)}) + + urls = [url async for url in discover_valid_sitemaps(['http://example.com/page'], http_client=http_client)] + + assert urls == ['http://example.com/custom-sitemap.xml'] + + +async def test_discover_sitemap_from_common_paths() -> None: + """Sitemap is found at common paths when robots.txt has none.""" + http_client = _make_mock_client( + {'/sitemap.xml': (200, b''), '/sitemap.txt': (200, b''), '/sitemap_index.xml': (200, b'')} + ) + + urls = [url async for url in discover_valid_sitemaps(['http://example.com/page'], http_client=http_client)] + + assert urls == [ + 'http://example.com/sitemap.xml', + 'http://example.com/sitemap.txt', + 'http://example.com/sitemap_index.xml', + ] + + +async def test_discover_sitemap_from_input_url() -> None: + """Input URL that is already a sitemap is yielded directly without checking common paths.""" + http_client = _make_mock_client({'/sitemap.txt': (200, b'')}) + + urls = [url async for url in discover_valid_sitemaps(['http://example.com/sitemap.xml'], http_client=http_client)] + + assert urls == ['http://example.com/sitemap.xml'] + + +async def test_discover_sitemap_deduplication() -> None: + """Sitemap URL found in robots.txt is not yielded again from common paths check.""" + robots_content = b'User-agent: *\nSitemap: http://example.com/sitemap.xml' + http_client = _make_mock_client( + { + 'robots.txt': (200, robots_content), + '/sitemap.xml': (200, b''), + } + ) + + urls = [url async for url in discover_valid_sitemaps(['http://example.com/page'], http_client=http_client)] + + assert urls == ['http://example.com/sitemap.xml'] + + +async def test_discover_sitemaps_multiple_domains() -> None: + """Sitemaps from multiple domains are all discovered.""" + http_client = _make_mock_client( + { + 'domain-a.com/sitemap.xml': (200, b''), + 'domain-b.com/sitemap.xml': (200, b''), + } + ) + + urls = [ + url + async for url in discover_valid_sitemaps( + ['http://domain-a.com/page', 'http://domain-b.com/page'], + http_client=http_client, + ) + ] + + assert set(urls) == { + 'http://domain-a.com/sitemap.xml', + 'http://domain-b.com/sitemap.xml', + } + + +async def test_discover_sitemap_url_without_host_skipped() -> None: + """URLs without a host are skipped.""" + http_client = _make_mock_client({}) + + urls = [url async for url in discover_valid_sitemaps(['not-a-valid-url'], http_client=http_client)] + + assert urls == []