diff --git a/src/link_scraper.py b/src/link_scraper.py index 5b189b1..7061440 100644 --- a/src/link_scraper.py +++ b/src/link_scraper.py @@ -10,6 +10,15 @@ from rich.style import StyleType from rich.text import Text +class NonHtmlContentError(httpx.HTTPError): + """Raised when the site's Content-Type header indicates non-HTML content.""" + + def __init__(self, message: str, *, request: httpx.Request, response: httpx.Response) -> None: + super().__init__(message) + self.request = request + self.response = response + + async def get_page_links( client: httpx.AsyncClient, url: httpx.URL, @@ -21,20 +30,30 @@ async def get_page_links( This function will also resolve relative URLs. Non http/https schemas will not be included. """ - res = await client.get(url) - if res.is_redirect and follow_redirects: - if not res.has_redirect_location: - raise httpx.HTTPStatusError( - f"Redirect response '{res.status_code} {res.reason_phrase}' " - "for url '{res.url} without Location header", - request=res.request, - response=res, + async with client.stream("GET", url) as res: + if res.is_redirect and follow_redirects: + if not res.has_redirect_location: + raise httpx.HTTPStatusError( + f"Redirect response '{res.status_code} {res.reason_phrase}' " + "for url '{res.url} without Location header", + request=res.request, + response=res, + ) + location = res.headers["Location"] + return await get_page_links( + client, httpx.URL(urljoin(str(url), location)), follow_redirects=follow_redirects ) - location = res.headers["Location"] - return await get_page_links(client, httpx.URL(urljoin(str(url), location)), follow_redirects=follow_redirects) - res.raise_for_status() - html = res.text + res.raise_for_status() + + # Make sure that we're getting back HTML content + content_type = res.headers.get("Content-Type", "") + if not content_type.startswith("text/html"): + raise NonHtmlContentError("The site content type isn't HTML", request=res.request, response=res) + + # Only read the rest of the data here, this prevents pulling large non-HTML files + await res.aread() + html = res.text soup = BeautifulSoup(html, features="html.parser") anchors = soup.find_all("a") @@ -63,12 +82,23 @@ def standard_urlmap_exception_suppressor(exc: Exception, url: httpx.URL) -> bool print_exc("Got ", (f"code {exc.response.status_code}", "red")) return True + if isinstance(exc, NonHtmlContentError): + print_exc( + "Got ", + ("Non-HTML Content-Type Header", "red"), + ", (", + (str(exc.response.headers.get("Content-Type", "")), "orange"), + ")", + ) + return True + if isinstance(exc, httpx.TransportError): print_exc("Got ", (exc.__class__.__qualname__, "red"), ", (", (str(exc), "orange"), ")") return True if isinstance(exc, ParserRejectedMarkup): print_exc("Parsing failure: ", ("Invalid HTML", "red")) + return True return False