Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 145 additions & 1 deletion src/crawlee/_utils/sitemap.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

import asyncio
import re
import zlib
from codecs import getincrementaldecoder
from collections import defaultdict
from contextlib import suppress
from dataclasses import dataclass
from datetime import datetime, timedelta
Expand All @@ -16,6 +18,9 @@
from typing_extensions import NotRequired, override
from yarl import URL

from crawlee._utils.web import is_status_code_successful
from crawlee.errors import ProxyError

if TYPE_CHECKING:
from collections.abc import AsyncGenerator
from xml.sax.xmlreader import AttributesImpl
Expand All @@ -27,6 +32,8 @@

VALID_CHANGE_FREQS = {'always', 'hourly', 'daily', 'weekly', 'monthly', 'yearly', 'never'}
SITEMAP_HEADERS = {'accept': 'text/plain, application/xhtml+xml, application/xml;q=0.9, */*;q=0.8'}
SITEMAP_URL_PATTERN = re.compile(r'\/sitemap\.(?:xml|txt)(?:\.gz)?$', re.IGNORECASE)
COMMON_SITEMAP_PATHS = ['/sitemap.xml', '/sitemap.txt', '/sitemap_index.xml']


@dataclass()
Expand Down Expand Up @@ -384,7 +391,7 @@ def urls(self) -> list[str]:
@classmethod
async def try_common_names(cls, url: str, http_client: HttpClient, proxy_info: ProxyInfo | None = None) -> Sitemap:
    """Try to load a sitemap from the common well-known paths on *url*'s host.

    Args:
        url: Any URL on the target host; only its scheme/host are used.
        http_client: `HttpClient` to use for making requests.
        proxy_info: Proxy configuration to use for requests.
    """
    base_url = URL(url)
    # Build one candidate per common path (e.g. /sitemap.xml, /sitemap.txt, /sitemap_index.xml).
    # NOTE: the stale pre-refactor assignment duplicating this line was removed.
    sitemap_urls = [str(base_url.with_path(path)) for path in COMMON_SITEMAP_PATHS]
    return await cls.load(sitemap_urls, http_client, proxy_info)

@classmethod
Expand Down Expand Up @@ -484,3 +491,140 @@ async def parse_sitemap(
yield result
else:
logger.warning(f'Invalid source configuration: {source}')


async def _merge_async_generators(*generators: AsyncGenerator) -> AsyncGenerator:
queue: asyncio.Queue = asyncio.Queue()

end_feed = object()

async def feed(gen: AsyncGenerator) -> None:
try:
async for item in gen:
await queue.put(item)
except Exception:
logger.warning(f'Error in generator: {gen}', exc_info=True)
finally:
await queue.put(end_feed)

tasks = [asyncio.create_task(feed(gen)) for gen in generators]
remaining_tasks = len(tasks)

try:
while remaining_tasks > 0:
item = await queue.get()
if item is end_feed:
remaining_tasks -= 1
else:
yield item
finally:
for task in tasks:
task.cancel()
await asyncio.gather(*tasks, return_exceptions=True)


async def _discover_for_hostname(
    hostname: str,
    hostname_urls: list[str],
    *,
    http_client: HttpClient,
    proxy_info: ProxyInfo | None = None,
    request_timeout: timedelta,
    method_for_checking: Literal['HEAD', 'GET'] = 'HEAD',
) -> AsyncGenerator[str, None]:
    """Yield candidate sitemap URLs for a single hostname, without duplicates.

    Discovery order: sitemaps declared in robots.txt first, then input URLs that
    already look like sitemap URLs, and finally — only when no input URL matched —
    common sitemap paths probed with an HTTP request.

    Args:
        hostname: Host these URLs belong to (used in log messages).
        hostname_urls: URLs from the caller that share this hostname.
        http_client: `HttpClient` to use for making requests.
        proxy_info: Proxy configuration to use for requests.
        request_timeout: Timeout for each probe request.
        method_for_checking: HTTP method used to probe common paths.
    """
    # Import here to avoid circular imports.
    from crawlee._utils.robots import RobotsTxtFile  # noqa: PLC0415

    domain_seen: set[str] = set()
    # Deduplicate while PRESERVING caller order (dict.fromkeys keeps insertion
    # order), so hostname_urls[0] — the base for robots.txt lookup and
    # common-path probing — is deterministic. list(set(...)) was order-unstable.
    hostname_urls = list(dict.fromkeys(hostname_urls))

    def _check_and_add(url: str) -> bool:
        # Return True only the first time a URL is seen for this hostname.
        if url in domain_seen:
            return False
        domain_seen.add(url)
        return True

    # Try getting sitemaps from robots.txt first.
    robots = await RobotsTxtFile.find(url=hostname_urls[0], http_client=http_client, proxy_info=proxy_info)
    for sitemap_url in robots.get_sitemaps():
        if _check_and_add(sitemap_url):
            yield sitemap_url

    # Check whether any provided URLs are themselves sitemap URLs.
    matching_sitemap_urls = [url for url in hostname_urls if SITEMAP_URL_PATTERN.search(url)]

    if matching_sitemap_urls:
        for sitemap_url in matching_sitemap_urls:
            if _check_and_add(sitemap_url):
                yield sitemap_url
    else:
        # Probe common sitemap locations; only successful (2xx/3xx) responses count.
        base_url = URL(hostname_urls[0])
        for path in COMMON_SITEMAP_PATHS:
            candidate = str(base_url.with_path(path))
            if candidate in domain_seen:
                continue
            try:
                response = await http_client.send_request(
                    candidate, method=method_for_checking, proxy_info=proxy_info, timeout=request_timeout
                )
                if is_status_code_successful(response.status_code) and _check_and_add(candidate):
                    yield candidate
            except ProxyError:
                logger.warning(f'Proxy error when checking {candidate} with sitemap discovery for {hostname}')
            except asyncio.TimeoutError:
                logger.warning(f'Timeout when checking {candidate} with sitemap discovery for {hostname}')
            except Exception:
                logger.warning(f'Error when checking {candidate} with sitemap discovery for {hostname}', exc_info=True)


async def discover_valid_sitemaps(
    urls: list[str],
    *,
    http_client: HttpClient,
    proxy_info: ProxyInfo | None = None,
    request_timeout: timedelta = timedelta(seconds=20),
    method_for_checking: Literal['HEAD', 'GET'] = 'HEAD',
) -> AsyncGenerator[str, None]:
    """Discover related sitemaps for the given URLs.

    Args:
        urls: List of URLs to discover sitemaps for.
        http_client: `HttpClient` to use for making requests.
        proxy_info: Proxy configuration to use for requests.
        request_timeout: Timeout for each request when checking for sitemaps.
        method_for_checking: HTTP method to use when checking for sitemap existence (HEAD or GET).
    """
    # Group the input URLs by hostname, skipping anything unparsable or host-less.
    by_host: defaultdict[str, list[str]] = defaultdict(list)
    for url in urls:
        try:
            host = URL(url).host
        except ValueError:
            logger.warning(f'Invalid URL {url} skipped')
            continue

        if not host:
            logger.warning(f'URL {url} without host skipped')
            continue

        by_host[host].append(url)

    # One discovery stream per hostname, all drained concurrently.
    streams = [
        _discover_for_hostname(
            host,
            host_urls,
            http_client=http_client,
            proxy_info=proxy_info,
            request_timeout=request_timeout,
            method_for_checking=method_for_checking,
        )
        for host, host_urls in by_host.items()
    ]

    # Cross-stream deduplication: yield each sitemap URL at most once.
    emitted: set[str] = set()
    async for candidate in _merge_async_generators(*streams):
        if candidate in emitted:
            continue
        emitted.add(candidate)
        yield candidate
11 changes: 9 additions & 2 deletions src/crawlee/_utils/web.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
from __future__ import annotations

from http import HTTPStatus


def is_status_code_client_error(value: int) -> bool:
    """Return `True` for 4xx status codes, `False` otherwise."""
    # The stale pre-refactor `return 400 <= value <= 499` line (a merge/diff
    # artifact left as unreachable dead code) was removed; the HTTPStatus-based
    # form is equivalent for integers: 400 <= value < 500.
    return HTTPStatus.BAD_REQUEST <= value < HTTPStatus.INTERNAL_SERVER_ERROR


def is_status_code_server_error(value: int) -> bool:
    """Return `True` for 5xx (and higher) status codes, `False` otherwise."""
    # The stale pre-refactor `return value >= 500` line (a merge/diff artifact
    # left as unreachable dead code) was removed; behavior is unchanged.
    return value >= HTTPStatus.INTERNAL_SERVER_ERROR


def is_status_code_successful(value: int) -> bool:
    """Return `True` for 2xx and 3xx status codes, `False` otherwise."""
    # "Successful" here deliberately includes redirects: [200, 400).
    return not (value < HTTPStatus.OK or value >= HTTPStatus.BAD_REQUEST)
104 changes: 102 additions & 2 deletions tests/unit/_utils/test_sitemap.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import base64
import gzip
from datetime import datetime
from typing import Any
from unittest.mock import AsyncMock, MagicMock

from yarl import URL

from crawlee._utils.sitemap import Sitemap, SitemapUrl, parse_sitemap
from crawlee.http_clients._base import HttpClient
from crawlee._utils.sitemap import Sitemap, SitemapUrl, discover_valid_sitemaps, parse_sitemap
from crawlee.http_clients._base import HttpClient, HttpResponse

BASIC_SITEMAP = """
<?xml version="1.0" encoding="UTF-8"?>
Expand Down Expand Up @@ -46,6 +48,23 @@
}


def _make_mock_client(url_map: dict[str, tuple[int, bytes]]) -> AsyncMock:
    """Build a mocked `HttpClient` serving canned (status, body) pairs keyed by URL substring.

    The first pattern (in dict order) contained in the requested URL wins;
    unmatched URLs get a 404 with an empty body.
    """

    async def send_request(url: str, **_kwargs: Any) -> HttpResponse:
        status, body = next(
            ((code, payload) for pattern, (code, payload) in url_map.items() if pattern in url),
            (404, b''),
        )
        response = MagicMock(spec=HttpResponse)
        response.status_code = status
        response.read = AsyncMock(return_value=body)
        return response

    client = AsyncMock(spec=HttpClient)
    client.send_request.side_effect = send_request
    return client


def compress_gzip(data: str) -> bytes:
    """Compress a string using gzip."""
    encoded = data.encode()
    return gzip.compress(encoded)
Expand Down Expand Up @@ -246,3 +265,84 @@ async def test_sitemap_from_string() -> None:

assert len(sitemap.urls) == 5
assert set(sitemap.urls) == BASIC_RESULTS


async def test_discover_sitemap_from_robots_txt() -> None:
    """Sitemap URL found in robots.txt is yielded."""
    http_client = _make_mock_client(
        {'robots.txt': (200, b'User-agent: *\nSitemap: http://example.com/custom-sitemap.xml')}
    )

    found = [sitemap async for sitemap in discover_valid_sitemaps(['http://example.com/page'], http_client=http_client)]

    assert found == ['http://example.com/custom-sitemap.xml']


async def test_discover_sitemap_from_common_paths() -> None:
    """Sitemap is found at common paths when robots.txt has none."""
    common_paths = ('/sitemap.xml', '/sitemap.txt', '/sitemap_index.xml')
    http_client = _make_mock_client({path: (200, b'') for path in common_paths})

    found = [sitemap async for sitemap in discover_valid_sitemaps(['http://example.com/page'], http_client=http_client)]

    assert found == [f'http://example.com{path}' for path in common_paths]


async def test_discover_sitemap_from_input_url() -> None:
    """Input URL that is already a sitemap is yielded directly without checking common paths."""
    http_client = _make_mock_client({'/sitemap.txt': (200, b'')})

    found = []
    async for sitemap in discover_valid_sitemaps(['http://example.com/sitemap.xml'], http_client=http_client):
        found.append(sitemap)

    assert found == ['http://example.com/sitemap.xml']


async def test_discover_sitemap_deduplication() -> None:
    """Sitemap URL found in robots.txt is not yielded again from common paths check."""
    responses = {
        'robots.txt': (200, b'User-agent: *\nSitemap: http://example.com/sitemap.xml'),
        '/sitemap.xml': (200, b''),
    }
    http_client = _make_mock_client(responses)

    found = [sitemap async for sitemap in discover_valid_sitemaps(['http://example.com/page'], http_client=http_client)]

    assert found == ['http://example.com/sitemap.xml']


async def test_discover_sitemaps_multiple_domains() -> None:
    """Sitemaps from multiple domains are all discovered."""
    domains = ('domain-a.com', 'domain-b.com')
    http_client = _make_mock_client({f'{domain}/sitemap.xml': (200, b'') for domain in domains})

    start_urls = [f'http://{domain}/page' for domain in domains]
    found = {sitemap async for sitemap in discover_valid_sitemaps(start_urls, http_client=http_client)}

    assert found == {f'http://{domain}/sitemap.xml' for domain in domains}


async def test_discover_sitemap_url_without_host_skipped() -> None:
    """URLs without a host are skipped."""
    http_client = _make_mock_client({})

    found = [sitemap async for sitemap in discover_valid_sitemaps(['not-a-valid-url'], http_client=http_client)]

    assert found == []
Loading