diff --git a/testtools/testcase.py b/testtools/testcase.py index e07d1ecb..c235dd03 100644 --- a/testtools/testcase.py +++ b/testtools/testcase.py @@ -23,7 +23,7 @@ import types import unittest from collections.abc import Callable, Iterator -from typing import TypeVar, cast +from typing import TYPE_CHECKING, TypeVar, cast, overload from unittest.case import SkipTest T = TypeVar("T") @@ -56,6 +56,12 @@ TestResult, ) +if TYPE_CHECKING: + if sys.version_info >= (3, 11): + from typing import Self + else: + from typing_extensions import Self + # Circular import: fixtures imports gather_details from here, we import # fixtures, leading to gather_details not being available and fixtures being # unable to import it. @@ -493,9 +499,27 @@ def assertIsInstance( # type: ignore[override] matcher = IsInstance(klass) self.assertThat(obj, matcher, msg or "") + @overload # type: ignore[override] + def assertRaises( + self, + expected_exception: type[BaseException] | tuple[type[BaseException]], + callable: Callable[..., object], + *args: object, + **kwargs: object, + ) -> BaseException: ... + + @overload # type: ignore[override] + def assertRaises( + self, + expected_exception: type[BaseException] | tuple[type[BaseException]], + callable: None = ..., + *args: object, + **kwargs: object, + ) -> "_AssertRaisesContext": ... + def assertRaises( # type: ignore[override] self, - expected_exception: type[BaseException], + expected_exception: type[BaseException] | tuple[type[BaseException]], callable: Callable[..., object] | None = None, *args: object, **kwargs: object, @@ -1206,7 +1230,10 @@ class _AssertRaisesContext: """ def __init__( - self, expected: type[BaseException], test_case: TestCase, msg: str | None = None + self, + expected: type[BaseException] | tuple[type[BaseException]], + test_case: TestCase, + msg: str | None = None, ) -> None: """Construct an `_AssertRaisesContext`. @@ -1219,7 +1246,7 @@ def __init__( self.msg = msg self.exception: BaseException | None = None - def __enter__(self) -> "_AssertRaisesContext": + def __enter__(self) -> "Self": return self def __exit__(