diff --git a/.coveragerc.toml b/.coveragerc.toml index 4ca5d2808bd..e1a2dc7b42c 100644 --- a/.coveragerc.toml +++ b/.coveragerc.toml @@ -19,4 +19,5 @@ exclude_also = [ 'assert False', ': \.\.\.(\s*#.*)?$', '^ +\.\.\.$', + 'pytest.fail\(' ] diff --git a/tests/conftest.py b/tests/conftest.py index 5a9c26628d2..336c0e14d40 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,6 +19,7 @@ from uuid import uuid4 import pytest +import trustme from multidict import CIMultiDict from yarl import URL @@ -38,16 +39,6 @@ from aiohttp.http import WS_KEY, HttpVersion11 from aiohttp.test_utils import get_unused_port_socket, loop_context -try: - import trustme - - # Check if the CA is available in runtime, MacOS on Py3.10 fails somehow - trustme.CA() - - TRUSTME: bool = True -except ImportError: - TRUSTME = False - def pytest_configure(config: pytest.Config) -> None: # On Windows with Python 3.10/3.11, proxy.py's threaded mode can leave @@ -122,8 +113,6 @@ def blockbuster(request: pytest.FixtureRequest) -> Iterator[None]: @pytest.fixture def tls_certificate_authority() -> trustme.CA: - if not TRUSTME: - pytest.xfail("trustme is not supported") return trustme.CA() @@ -213,8 +202,6 @@ def unix_sockname( # Ref: https://unix.stackexchange.com/a/367012/27133 sock_file_name = "unix.sock" - unique_prefix = f"{uuid4()!s}-" - unique_prefix_len = len(unique_prefix.encode()) root_tmp_dir = Path("/tmp").resolve() os_tmp_dir = Path(os.getenv("TMPDIR", "/tmp")).resolve() @@ -249,7 +236,7 @@ def assert_sock_fits(sock_path: str) -> None: unique_paths = [p for n, p in enumerate(paths) if p not in paths[:n]] paths_num = len(unique_paths) - for num, tmp_dir_path in enumerate(paths, 1): + for num, tmp_dir_path in enumerate(paths, 1): # pragma: no branch with make_tmp_dir(tmp_dir_path) as tmps: tmpd = Path(tmps).resolve() sock_path = str(tmpd / sock_file_name) @@ -261,12 +248,6 @@ def assert_sock_fits(sock_path: str) -> None: assert_sock_fits(sock_path) if sock_path_len <= max_sock_len: - if max_sock_len - sock_path_len >= unique_prefix_len: - # If we're lucky to have extra space in the path, - # let's also make it more unique - sock_path = str(tmpd / "".join((unique_prefix, sock_file_name))) - # Double-checking it: - assert_sock_fits(sock_path) yield sock_path return diff --git a/tests/test_classbasedview.py b/tests/test_classbasedview.py index be40b4028ab..a9a02c9f83b 100644 --- a/tests/test_classbasedview.py +++ b/tests/test_classbasedview.py @@ -28,7 +28,7 @@ async def get(self) -> web.StreamResponse: async def test_render_unknown_method() -> None: class MyView(View): async def get(self) -> web.StreamResponse: - return web.Response(text="OK") + assert False options = get @@ -43,7 +43,7 @@ async def get(self) -> web.StreamResponse: async def test_render_unsupported_method() -> None: class MyView(View): async def get(self) -> web.StreamResponse: - return web.Response(text="OK") + assert False options = delete = get diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index da3ad8a5b4a..971e7576d52 100644 --- a/tests/test_client_functional.py +++ b/tests/test_client_functional.py @@ -24,8 +24,8 @@ import brotlicffi as brotli except ImportError: import brotli -except ImportError: - brotli = None # pragma: no cover +except ImportError: # pragma: no cover + brotli = None try: from backports.zstd import ZstdCompressor @@ -397,7 +397,7 @@ async def handler(request: web.Request) -> web.Response: client = await aiohttp_client(app) async def data_gen() -> AsyncIterator[bytes]: - for _ in range(2): + for _ in range(2): # pragma: no branch yield b"just data" await asyncio.sleep(0.1) @@ -430,7 +430,7 @@ async def handler(request: web.Request) -> web.Response: client = await aiohttp_client(app) async def data_gen() -> AsyncIterator[bytes]: - for _ in range(2): + for _ in range(2): # pragma: no branch yield b"just data" await asyncio.sleep(0.1) diff --git a/tests/test_client_middleware.py b/tests/test_client_middleware.py index 222e912d3a9..bfb1ccc39e3 100644 --- a/tests/test_client_middleware.py +++ b/tests/test_client_middleware.py @@ -633,10 +633,8 @@ async def test_request_middleware_overrides_session_middleware_with_specific( request_middleware_called = False async def handler(request: web.Request) -> web.Response: - auth_header = request.headers.get("Authorization") - if auth_header: - return web.Response(text=f"Auth: {auth_header}") - return web.Response(text="No auth") + auth_header = request.headers["Authorization"] + return web.Response(text=f"Auth: {auth_header}") async def session_middleware( request: ClientRequest, handler: ClientHandlerType diff --git a/tests/test_client_session.py b/tests/test_client_session.py index e40001bb307..9853aea2347 100644 --- a/tests/test_client_session.py +++ b/tests/test_client_session.py @@ -629,8 +629,8 @@ async def create_connection( session._connector, "_release", autospec=True, spec_set=True ): with pytest.raises(UnexpectedException): - async with session.request("get", "http://example.com") as resp: - await resp.text() + async with session.request("get", "http://example.com"): + pass # normally called during garbage collection. triggers an exception # if the connection wasn't already closed diff --git a/tests/test_client_ws_functional.py b/tests/test_client_ws_functional.py index c58613c6ca9..22c03f7c6c4 100644 --- a/tests/test_client_ws_functional.py +++ b/tests/test_client_ws_functional.py @@ -76,7 +76,7 @@ async def handler(request: web.Request) -> NoReturn: with pytest.raises(WSMessageTypeError): await resp.receive_bytes() - await resp.close() + await resp.close() async def test_recv_bytes_after_close(aiohttp_client: AiohttpClient) -> None: @@ -97,7 +97,7 @@ async def handler(request: web.Request) -> NoReturn: match=f"Received message {WSMsgType.CLOSE}:.+ is not WSMsgType.BINARY", ): await resp.receive_bytes() - await resp.close() + await resp.close() async def test_send_recv_bytes(aiohttp_client: AiohttpClient) -> None: @@ -142,8 +142,7 @@ async def handler(request: web.Request) -> NoReturn: with pytest.raises(WSMessageTypeError): await resp.receive_str() - - await resp.close() + await resp.close() async def test_recv_text_after_close(aiohttp_client: AiohttpClient) -> None: @@ -164,7 +163,7 @@ async def handler(request: web.Request) -> NoReturn: match=f"Received message {WSMsgType.CLOSE}:.+ is not WSMsgType.TEXT", ): await resp.receive_str() - await resp.close() + await resp.close() async def test_send_recv_json(aiohttp_client: AiohttpClient) -> None: @@ -1151,13 +1150,7 @@ async def handler(request: web.Request) -> web.WebSocketResponse: async def test_send_recv_compress_wbit_error(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.WebSocketResponse: - ws = web.WebSocketResponse() - await ws.prepare(request) - - msg = await ws.receive_bytes() - await ws.send_bytes(msg + b"/answer") - await ws.close() - return ws + assert False app = web.Application() app.router.add_route("GET", "/", handler) diff --git a/tests/test_connector.py b/tests/test_connector.py index a3fd3626157..5b8f1e36db9 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -217,8 +217,6 @@ async def test_del(loop: asyncio.AbstractEventLoop, key: ConnectionKey) -> None: "connections": mock.ANY, "message": "Unclosed connector", } - if loop.get_debug(): - msg["source_traceback"] = mock.ANY exc_handler.assert_called_with(loop, msg) @@ -2118,9 +2116,6 @@ async def test_cleanup3(loop: asyncio.AbstractEventLoop, key: ConnectionKey) -> async def test_cleanup_closed( loop: asyncio.AbstractEventLoop, mocker: MockerFixture ) -> None: - if not hasattr(loop, "__dict__"): - pytest.skip("can not override loop attributes") - m = mocker.spy(loop, "call_at") conn = aiohttp.BaseConnector(enable_cleanup_closed=True) @@ -4121,10 +4116,7 @@ async def _resolve_host( first_conn = next(iter(conn._conns.values()))[0][0] assert first_conn.transport is not None - try: - _sslcontext = first_conn.transport._ssl_protocol._sslcontext # type: ignore[attr-defined] - except AttributeError: - _sslcontext = first_conn.transport._sslcontext # type: ignore[attr-defined] + _sslcontext = first_conn.transport._ssl_protocol._sslcontext # type: ignore[attr-defined] assert _sslcontext is client_ssl_ctx r.close() @@ -4531,10 +4523,6 @@ async def await_connection_and_check_waiters() -> None: connection.close() async def allow_connection_and_add_dummy_waiter() -> None: - # `asyncio.gather` may execute coroutines not in order. - # Skip one event loop run cycle in such a case. - if connection_key not in connector._waiters: - await asyncio.sleep(0) list(connector._waiters[connection_key])[0].set_result(None) del connector._waiters[connection_key] connector._waiters[connection_key][dummy_waiter] = None diff --git a/tests/test_http_parser.py b/tests/test_http_parser.py index f99791af4a3..e237aad6a88 100644 --- a/tests/test_http_parser.py +++ b/tests/test_http_parser.py @@ -35,7 +35,7 @@ import brotlicffi as brotli except ImportError: import brotli -except ImportError: +except ImportError: # pragma: no cover brotli = None try: diff --git a/tests/test_imports.py b/tests/test_imports.py index 0d220a656ed..0779e7fc779 100644 --- a/tests/test_imports.py +++ b/tests/test_imports.py @@ -58,7 +58,7 @@ def test_import_time(pytester: pytest.Pytester) -> None: finally: if old_path is None: os.environ.pop("PYTHONPATH") - else: + else: # pragma: no cover os.environ["PYTHONPATH"] = old_path assert best_time_ms < IMPORT_TIME_THRESHOLD_MS diff --git a/tests/test_loop.py b/tests/test_loop.py index eec0057748a..0feaf5ce1a5 100644 --- a/tests/test_loop.py +++ b/tests/test_loop.py @@ -48,7 +48,7 @@ def target() -> None: with loop_context() as loop: assert asyncio.get_event_loop() is loop loop.run_until_complete(test_subprocess_co(loop)) - except Exception as exc: + except Exception as exc: # pragma: no cover nonlocal child_exc child_exc = exc diff --git a/tests/test_multipart.py b/tests/test_multipart.py index 5444817d5a4..9422a68fbb6 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -242,7 +242,7 @@ async def test_read_incomplete_body_chunked(self) -> None: obj = aiohttp.BodyPartReader(BOUNDARY, d, stream) result = b"" with pytest.raises(ValueError): - for _ in range(4): + for _ in range(4): # pragma: no branch result += await obj.read_chunk(7) assert b"Hello, World!\r\n-" == result diff --git a/tests/test_payload.py b/tests/test_payload.py index dd25ccfc459..205a3efdf81 100644 --- a/tests/test_payload.py +++ b/tests/test_payload.py @@ -65,7 +65,7 @@ def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str: assert False async def write(self, writer: AbstractStreamWriter) -> None: - pass + """Dummy write.""" def test_register_type(registry: payload.PayloadRegistry) -> None: @@ -146,8 +146,7 @@ def test_string_io_payload() -> None: def test_async_iterable_payload_default_content_type() -> None: async def gen() -> AsyncIterator[bytes]: - return - yield b"abc" # type: ignore[unreachable] # pragma: no cover + yield b"abc" # pragma: no cover p = payload.AsyncIterablePayload(gen()) assert p.content_type == "application/octet-stream" @@ -155,8 +154,7 @@ async def gen() -> AsyncIterator[bytes]: def test_async_iterable_payload_explicit_content_type() -> None: async def gen() -> AsyncIterator[bytes]: - return - yield b"abc" # type: ignore[unreachable] # pragma: no cover + yield b"abc" # pragma: no cover p = payload.AsyncIterablePayload(gen(), content_type="application/custom") assert p.content_type == "application/custom" diff --git a/tests/test_resolver.py b/tests/test_resolver.py index 13e494d24cd..e19df43e7e2 100644 --- a/tests/test_resolver.py +++ b/tests/test_resolver.py @@ -2,7 +2,7 @@ import gc import ipaddress import socket -from collections.abc import Awaitable, Callable, Collection, Generator, Iterable +from collections.abc import Awaitable, Callable, Collection, Generator from ipaddress import ip_address from typing import Any, NamedTuple from unittest.mock import Mock, create_autospec, patch @@ -21,7 +21,7 @@ import aiodns getaddrinfo = hasattr(aiodns.DNSResolver, "getaddrinfo") -except ImportError: +except ImportError: # pragma: no cover aiodns = None # type: ignore[assignment] getaddrinfo = False @@ -110,11 +110,6 @@ def __init__(self, host: str) -> None: self.service = None -class FakeQueryResult: - def __init__(self, host: str) -> None: - self.host = host - - async def fake_aiodns_getaddrinfo_ipv4_result( hosts: Collection[str], ) -> FakeAIODNSAddrInfoIPv4Result: @@ -133,10 +128,6 @@ async def fake_aiodns_getnameinfo_ipv6_result( return FakeAIODNSNameInfoIPv6Result(host) -async def fake_query_result(result: Iterable[str]) -> list[FakeQueryResult]: - return [FakeQueryResult(host=h) for h in result] - - def fake_addrinfo(hosts: Collection[str]) -> Callable[..., Awaitable[_AddrInfo4]]: async def fake(*args: Any, **kwargs: Any) -> _AddrInfo4: if not hosts: @@ -440,11 +431,6 @@ def test_aio_dns_is_default() -> None: assert DefaultResolver is AsyncResolver -@pytest.mark.skipif(getaddrinfo, reason="aiodns <3.2.0 required") -def test_threaded_resolver_is_default() -> None: - assert DefaultResolver is ThreadedResolver - - @pytest.mark.skipif(not getaddrinfo, reason="aiodns >=3.2.0 required") async def test_dns_resolver_manager_sharing( dns_resolver_manager: _DNSResolverManager, diff --git a/tests/test_run_app.py b/tests/test_run_app.py index f865de13a7e..9e1fdb63176 100644 --- a/tests/test_run_app.py +++ b/tests/test_run_app.py @@ -45,14 +45,14 @@ del _has_unix_domain_socks, _abstract_path_failed HAS_IPV6: bool = socket.has_ipv6 -if HAS_IPV6: +if HAS_IPV6: # pragma: no branch # The socket.has_ipv6 flag may be True if Python was built with IPv6 # support, but the target system still may not have it. # So let's ensure that we really have IPv6 support. try: with socket.socket(socket.AF_INET6, socket.SOCK_STREAM): pass - except OSError: + except OSError: # pragma: no cover HAS_IPV6 = False @@ -692,12 +692,10 @@ def test_sigint() -> None: skip_if_on_windows() with subprocess.Popen( - [sys.executable, "-u", "-c", _script_test_signal], + (sys.executable, "-u", "-c", _script_test_signal), stdout=subprocess.PIPE, ) as proc: - for line in proc.stdout: # type: ignore[union-attr] - if line.startswith(b"======== Running on"): - break + assert proc.stdout.readline().startswith(b"======== Running on") # type: ignore[union-attr] proc.send_signal(signal.SIGINT) assert proc.wait() == 0 @@ -706,12 +704,10 @@ def test_sigterm() -> None: skip_if_on_windows() with subprocess.Popen( - [sys.executable, "-u", "-c", _script_test_signal], + (sys.executable, "-u", "-c", _script_test_signal), stdout=subprocess.PIPE, ) as proc: - for line in proc.stdout: # type: ignore[union-attr] - if line.startswith(b"======== Running on"): - break + assert proc.stdout.readline().startswith(b"======== Running on") # type: ignore[union-attr] proc.terminate() assert proc.wait() == 0 @@ -1110,7 +1106,7 @@ def test_shutdown_timeout_handler(self, unused_port_socket: socket.socket) -> No async def task() -> None: nonlocal finished await asyncio.sleep(2) - finished = True + finished = True # pragma: no cover t, connection_count = self.run_app(sock, 1, task) @@ -1159,7 +1155,7 @@ async def test(sess: ClientSession) -> None: # Use a new session to try and open a new connection. async with ClientSession() as sess: async with sess.get(f"http://127.0.0.1:{port}/"): - pass + assert False # Should fail before here assert finished is False t, connection_count = self.run_app(sock, 10, task, test) @@ -1261,7 +1257,7 @@ async def ws_handler(request: web.Request) -> web.WebSocketResponse: await ws.prepare(request) request.app[WS].add(ws) async for msg in ws: - pass + assert False # No messages actually sent nonlocal server_finished server_finished = True return ws @@ -1278,7 +1274,7 @@ async def test() -> None: pass async for msg in ws: - pass + assert False # No messages actually sent nonlocal client_finished client_finished = True @@ -1316,8 +1312,8 @@ async def test() -> None: async def test_resp(sess: ClientSession) -> None: t = ClientTimeout(total=0.4) with pytest.raises(asyncio.TimeoutError): - async with sess.get(f"http://127.0.0.1:{port}/", timeout=t) as resp: - assert await resp.text() == "FOO" + async with sess.get(f"http://127.0.0.1:{port}/", timeout=t): + assert False # Should timeout before this actions.append("CANCELLED") async with ClientSession() as sess: diff --git a/tests/test_tcp_helpers.py b/tests/test_tcp_helpers.py index 74770d108d8..2464be08c98 100644 --- a/tests/test_tcp_helpers.py +++ b/tests/test_tcp_helpers.py @@ -13,7 +13,7 @@ try: with socket.socket(socket.AF_INET6, socket.SOCK_STREAM): pass - except OSError: + except OSError: # pragma: no cover has_ipv6 = False diff --git a/tests/test_test_utils.py b/tests/test_test_utils.py index bbab015061f..dba8198300a 100644 --- a/tests/test_test_utils.py +++ b/tests/test_test_utils.py @@ -41,11 +41,8 @@ async def websocket_handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) msg = await ws.receive() - if msg.type == aiohttp.WSMsgType.TEXT: - if msg.data == "close": - await ws.close() - else: - await ws.send_str(msg.data + "/answer") + assert msg.type == aiohttp.WSMsgType.TEXT + await ws.send_str(msg.data + "/answer") return ws diff --git a/tests/test_urldispatch.py b/tests/test_urldispatch.py index 1bc7a7b638e..6603389f16d 100644 --- a/tests/test_urldispatch.py +++ b/tests/test_urldispatch.py @@ -374,8 +374,8 @@ def test_add_static_path_checks( """Test that static paths must exist and be directories.""" with pytest.raises(ValueError, match="does not exist"): router.add_static("/", tmp_path / "does-not-exist") - with pytest.raises(ValueError, match="is not a directory"): - router.add_static("/", __file__) + with pytest.raises(ValueError, match="is not a directory"): + router.add_static("/", __file__) def test_add_static_path_resolution(router: web.UrlDispatcher) -> None: @@ -1300,16 +1300,9 @@ def test_frozen_app_on_subapp(app: web.Application) -> None: def test_set_options_route(router: web.UrlDispatcher) -> None: resource = router.add_static("/static", pathlib.Path(aiohttp.__file__).parent) - options = None - for route in resource: - if route.method == "OPTIONS": - options = route - assert options is None + assert all(r.method != "OPTIONS" for r in resource) resource.set_options_route(make_handler()) - for route in resource: - if route.method == "OPTIONS": - options = route - assert options is not None + assert any(r.method == "OPTIONS" for r in resource) with pytest.raises(RuntimeError): resource.set_options_route(make_handler()) diff --git a/tests/test_web_functional.py b/tests/test_web_functional.py index 71dc53b500e..730d662ced4 100644 --- a/tests/test_web_functional.py +++ b/tests/test_web_functional.py @@ -1991,7 +1991,7 @@ async def handler(request: web.Request) -> web.StreamResponse: ): await resp.drain() await asyncio.sleep(10) - return resp + assert False app = web.Application() app.router.add_route("GET", "/", handler) diff --git a/tests/test_web_request_handler.py b/tests/test_web_request_handler.py index ee30e485f1b..3788e71f874 100644 --- a/tests/test_web_request_handler.py +++ b/tests/test_web_request_handler.py @@ -4,7 +4,7 @@ async def serve(request: web.BaseRequest) -> web.Response: - return web.Response() + assert False async def test_repr() -> None: diff --git a/tests/test_web_response.py b/tests/test_web_response.py index becbfedc965..76d52cca1aa 100644 --- a/tests/test_web_response.py +++ b/tests/test_web_response.py @@ -1412,9 +1412,7 @@ async def test_response_prepared_after_header_preparation() -> None: async def _strip_server(req: web.Request, res: web.Response) -> None: assert "Server" in res.headers - - if "Server" in res.headers: - del res.headers["Server"] + del res.headers["Server"] app = mock.create_autospec(web.Application, spec_set=True) app.on_response_prepare = aiosignal.Signal(app) diff --git a/tests/test_web_sendfile_functional.py b/tests/test_web_sendfile_functional.py index 21249daa371..87be2db182b 100644 --- a/tests/test_web_sendfile_functional.py +++ b/tests/test_web_sendfile_functional.py @@ -4,7 +4,7 @@ import pathlib import socket from collections.abc import Iterable, Iterator -from typing import NoReturn, Protocol +from typing import Protocol from unittest import mock import pytest @@ -60,19 +60,6 @@ def hello_txt( return hello[encoding] -@pytest.fixture -def loop_with_mocked_native_sendfile( - loop: asyncio.AbstractEventLoop, -) -> Iterator[asyncio.AbstractEventLoop]: - def sendfile(transport: object, fobj: object, offset: int, count: int) -> NoReturn: - if count == 0: - raise ValueError("count must be a positive integer (got 0)") - raise NotImplementedError - - with mock.patch.object(loop, "sendfile", sendfile): - yield loop - - @pytest.fixture(params=["sendfile", "no_sendfile"], ids=["sendfile", "no_sendfile"]) def sender(request: SubRequest, loop: asyncio.AbstractEventLoop) -> Iterator[_Sender]: sendfile_mock = None @@ -154,12 +141,10 @@ async def handler(request: web.Request) -> web.FileResponse: async def test_zero_bytes_file_mocked_native_sendfile( aiohttp_client: AiohttpClient, - loop_with_mocked_native_sendfile: asyncio.AbstractEventLoop, ) -> None: filepath = pathlib.Path(__file__).parent / "data.zero_bytes" async def handler(request: web.Request) -> web.FileResponse: - asyncio.set_event_loop(loop_with_mocked_native_sendfile) return web.FileResponse(filepath) app = web.Application() diff --git a/tests/test_web_server.py b/tests/test_web_server.py index 488e6f8d843..2cd364e0317 100644 --- a/tests/test_web_server.py +++ b/tests/test_web_server.py @@ -388,8 +388,7 @@ async def on_request(request: web.Request) -> web.Response: except asyncio.CancelledError: event.set() raise - else: - raise web.HTTPInternalServerError() + assert False app = web.Application() app.router.add_route("GET", "/", on_request) diff --git a/tests/test_web_urldispatcher.py b/tests/test_web_urldispatcher.py index 468b39a1c52..144bd9cd03e 100644 --- a/tests/test_web_urldispatcher.py +++ b/tests/test_web_urldispatcher.py @@ -570,7 +570,7 @@ async def test_access_special_resource( unix_sockname: str, aiohttp_client: AiohttpClient ) -> None: """Test access to non-regular files is forbidden using a UNIX domain socket.""" - if not getattr(socket, "AF_UNIX", None): + if not getattr(socket, "AF_UNIX", None): # pragma: no cover pytest.skip("UNIX domain sockets not supported") my_special = pathlib.Path(unix_sockname) diff --git a/tests/test_web_websocket_functional.py b/tests/test_web_websocket_functional.py index 0e41faa21f2..1e202649c6a 100644 --- a/tests/test_web_websocket_functional.py +++ b/tests/test_web_websocket_functional.py @@ -62,6 +62,8 @@ async def handler(request: web.Request) -> web.WebSocketResponse: resp = await ws.receive() assert resp.data == expected_value + await ws.receive() # Handle close + async def test_websocket_json_invalid_message( loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient @@ -69,14 +71,10 @@ async def test_websocket_json_invalid_message( async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) - try: + with pytest.raises(ValueError): await ws.receive_json() - except ValueError: - await ws.send_str("ValueError was raised") - else: - raise Exception("No Exception") - finally: - await ws.close() + await ws.send_str("ValueError was raised") + await ws.close() return ws app = web.Application() @@ -90,6 +88,8 @@ async def handler(request: web.Request) -> web.WebSocketResponse: data = await ws.receive_str() assert "ValueError was raised" in data + await ws.receive() # Handle close + async def test_websocket_send_json( loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient @@ -115,6 +115,8 @@ async def handler(request: web.Request) -> web.WebSocketResponse: data = await ws.receive_json() assert data["test"] == expected_value + await ws.receive() # Handle close + async def test_websocket_receive_json( loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient @@ -142,6 +144,8 @@ async def handler(request: web.Request) -> web.WebSocketResponse: resp = await ws.receive() assert resp.data == expected_value + await ws.receive() # Handle close + async def test_send_recv_text( loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient @@ -308,9 +312,6 @@ async def handler(request: web.Request) -> web.WebSocketResponse: msg = await ws.receive() assert msg.type == WSMsgType.CLOSING - msg = await ws.receive() - assert msg.type == WSMsgType.CLOSING - await asyncio.sleep(0) msg = await ws.receive() @@ -348,9 +349,6 @@ async def handler(request: web.Request) -> web.WebSocketResponse: msg = await ws.receive() assert msg.type == WSMsgType.CLOSING - msg = await ws.receive() - assert msg.type == WSMsgType.CLOSING - await asyncio.sleep(0) msg = await ws.receive() @@ -948,9 +946,9 @@ async def handler(request: web.Request) -> web.WebSocketResponse: messages = [] async for msg in ws: messages.append(msg) - if "stop" == msg.data: - await ws.send_str("stopping") - await ws.close() + assert "stop" == msg.data + await ws.send_str("stopping") + await ws.close() assert 1 == len(messages) assert messages[0].type == WSMsgType.TEXT @@ -1001,6 +999,8 @@ async def handler(request: web.Request) -> web.StreamResponse: data = await ws.receive_str() assert data == "OK" + await ws.receive() # Handle close + async def test_receive_str_nonstring( loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient @@ -1022,6 +1022,8 @@ async def handler(request: web.Request) -> web.WebSocketResponse: with pytest.raises(TypeError): await ws.receive_str() + await ws.receive() # Handle close + async def test_receive_bytes_nonbytes( loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient diff --git a/tests/test_websocket_parser.py b/tests/test_websocket_parser.py index 6de09a2cb00..26d1a275327 100644 --- a/tests/test_websocket_parser.py +++ b/tests/test_websocket_parser.py @@ -68,8 +68,8 @@ def build_frame( compressobj = ZLibBackend.compressobj(wbits=-9) message = compressobj.compress(message) message = message + compressobj.flush(ZLibBackend.Z_SYNC_FLUSH) - if message.endswith(WS_DEFLATE_TRAILING): - message = message[:-4] + assert message.endswith(WS_DEFLATE_TRAILING) + message = message[:-4] msg_length = len(message) if is_fin: @@ -595,7 +595,6 @@ def test_parse_compress_error_frame(parser: PatchableWebSocketReader) -> None: with pytest.raises(WebSocketError) as ctx: parser.parse_frame(struct.pack("!BB", 0b11000001, 0b00000001)) - parser.parse_frame(b"1") assert ctx.value.code == WSCloseCode.PROTOCOL_ERROR @@ -604,7 +603,6 @@ def test_parse_no_compress_frame_single(out: WebSocketDataQueue) -> None: parser_no_compress = PatchableWebSocketReader(out, 0, compress=False) with pytest.raises(WebSocketError) as ctx: parser_no_compress.parse_frame(struct.pack("!BB", 0b11000001, 0b00000001)) - parser_no_compress.parse_frame(b"1") assert ctx.value.code == WSCloseCode.PROTOCOL_ERROR