From 53d6bf39d8614c317d5737ab560541e1df5da55c Mon Sep 17 00:00:00 2001 From: Randolf Scholz Date: Fri, 16 Jan 2026 20:33:15 +0100 Subject: [PATCH 1/2] allow arbirary types in set.discard --- stdlib/@tests/test_cases/builtins/check_set.py | 4 ++++ stdlib/builtins.pyi | 2 +- stdlib/typing.pyi | 2 +- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/stdlib/@tests/test_cases/builtins/check_set.py b/stdlib/@tests/test_cases/builtins/check_set.py index 604251b0bb67..5557b03f745e 100644 --- a/stdlib/@tests/test_cases/builtins/check_set.py +++ b/stdlib/@tests/test_cases/builtins/check_set.py @@ -12,3 +12,7 @@ def test_set_difference(x: set[Literal["foo", "bar"]], y: set[str], z: set[int]) assert_type(z - x, set[int]) assert_type(y - z, set[str]) assert_type(z - y, set[int]) + + +def test_set_discard(x: set[Literal["foo", "bar"]], key: str) -> None: + x.discard(key) # OK diff --git a/stdlib/builtins.pyi b/stdlib/builtins.pyi index 5695b17ca36d..8c3c5dfb0670 100644 --- a/stdlib/builtins.pyi +++ b/stdlib/builtins.pyi @@ -1264,7 +1264,7 @@ class set(MutableSet[_T]): def copy(self) -> set[_T]: ... def difference(self, *s: Iterable[Any]) -> set[_T]: ... def difference_update(self, *s: Iterable[Any]) -> None: ... - def discard(self, element: _T, /) -> None: ... + def discard(self, element: object, /) -> None: ... def intersection(self, *s: Iterable[Any]) -> set[_T]: ... def intersection_update(self, *s: Iterable[Any]) -> None: ... def isdisjoint(self, s: Iterable[Any], /) -> bool: ... diff --git a/stdlib/typing.pyi b/stdlib/typing.pyi index af1d1650da41..7ff6a5b27887 100644 --- a/stdlib/typing.pyi +++ b/stdlib/typing.pyi @@ -721,7 +721,7 @@ class MutableSet(AbstractSet[_T]): @abstractmethod def add(self, value: _T, /) -> None: ... @abstractmethod - def discard(self, value: _T, /) -> None: ... + def discard(self, value: Any, /) -> None: ... # Mixin methods def clear(self) -> None: ... def pop(self) -> _T: ... From ade8cab600bdb81974ea80c87105d03af547187e Mon Sep 17 00:00:00 2001 From: Randolf Scholz Date: Sat, 28 Feb 2026 17:11:16 +0100 Subject: [PATCH 2/2] canonicalized set/frozenset signatures --- .../@tests/test_cases/builtins/check_set.py | 43 ++++++++++++++++++- stdlib/builtins.pyi | 22 +++++----- stdlib/typing.pyi | 2 +- 3 files changed, 53 insertions(+), 14 deletions(-) diff --git a/stdlib/@tests/test_cases/builtins/check_set.py b/stdlib/@tests/test_cases/builtins/check_set.py index 5557b03f745e..89cb8683bfe1 100644 --- a/stdlib/@tests/test_cases/builtins/check_set.py +++ b/stdlib/@tests/test_cases/builtins/check_set.py @@ -14,5 +14,44 @@ def test_set_difference(x: set[Literal["foo", "bar"]], y: set[str], z: set[int]) assert_type(z - y, set[int]) -def test_set_discard(x: set[Literal["foo", "bar"]], key: str) -> None: - x.discard(key) # OK +def test_set_interface_overlapping_type(s: set[Literal["foo", "bar"]], y: set[str], key: str) -> None: + s.add(key) # type: ignore + s.discard(key) + s.remove(key) # type: ignore + s.difference_update(y) + s.intersection_update(y) + s.symmetric_difference_update(y) # type: ignore + s.update(y) # type: ignore + + assert_type(s.difference(y), set[Literal["foo", "bar"]]) + assert_type(s.intersection(y), set[Literal["foo", "bar"]]) + assert_type(s.isdisjoint(y), bool) + assert_type(s.issubset(y), bool) + assert_type(s.issuperset(y), bool) + assert_type(s.symmetric_difference(y), set[str]) + assert_type(s.union(y), set[str]) + + assert_type(s - y, set[Literal["foo", "bar"]]) + assert_type(s & y, set[Literal["foo", "bar"]]) + assert_type(s | y, set[str]) + assert_type(s ^ y, set[str]) + + s -= y + s &= y + s |= y # type: ignore + s ^= y # type: ignore + + +def test_frozenset_interface(s: frozenset[Literal["foo", "bar"]], y: frozenset[str]) -> None: + assert_type(s.difference(y), frozenset[Literal["foo", "bar"]]) + assert_type(s.intersection(y), frozenset[Literal["foo", "bar"]]) + assert_type(s.isdisjoint(y), bool) + assert_type(s.issubset(y), bool) + assert_type(s.issuperset(y), bool) + assert_type(s.symmetric_difference(y), frozenset[str]) + assert_type(s.union(y), frozenset[str]) + + assert_type(s - y, frozenset[Literal["foo", "bar"]]) + assert_type(s & y, frozenset[Literal["foo", "bar"]]) + assert_type(s | y, frozenset[str]) + assert_type(s ^ y, frozenset[str]) diff --git a/stdlib/builtins.pyi b/stdlib/builtins.pyi index 8c3c5dfb0670..a22bb400adfe 100644 --- a/stdlib/builtins.pyi +++ b/stdlib/builtins.pyi @@ -1262,16 +1262,16 @@ class set(MutableSet[_T]): def __init__(self, iterable: Iterable[_T], /) -> None: ... def add(self, element: _T, /) -> None: ... def copy(self) -> set[_T]: ... - def difference(self, *s: Iterable[Any]) -> set[_T]: ... - def difference_update(self, *s: Iterable[Any]) -> None: ... + def difference(self, *s: Iterable[object]) -> set[_T]: ... + def difference_update(self, *s: Iterable[object]) -> None: ... def discard(self, element: object, /) -> None: ... - def intersection(self, *s: Iterable[Any]) -> set[_T]: ... - def intersection_update(self, *s: Iterable[Any]) -> None: ... - def isdisjoint(self, s: Iterable[Any], /) -> bool: ... - def issubset(self, s: Iterable[Any], /) -> bool: ... - def issuperset(self, s: Iterable[Any], /) -> bool: ... + def intersection(self, *s: Iterable[object]) -> set[_T]: ... + def intersection_update(self, *s: Iterable[object]) -> None: ... + def isdisjoint(self, s: Iterable[object], /) -> bool: ... + def issubset(self, s: Iterable[object], /) -> bool: ... + def issuperset(self, s: Iterable[object], /) -> bool: ... def remove(self, element: _T, /) -> None: ... - def symmetric_difference(self, s: Iterable[_T], /) -> set[_T]: ... + def symmetric_difference(self, s: Iterable[_S], /) -> set[_T | _S]: ... def symmetric_difference_update(self, s: Iterable[_T], /) -> None: ... def union(self, *s: Iterable[_S]) -> set[_T | _S]: ... def update(self, *s: Iterable[_T]) -> None: ... @@ -1303,15 +1303,15 @@ class frozenset(AbstractSet[_T_co]): def copy(self) -> frozenset[_T_co]: ... def difference(self, *s: Iterable[object]) -> frozenset[_T_co]: ... def intersection(self, *s: Iterable[object]) -> frozenset[_T_co]: ... - def isdisjoint(self, s: Iterable[_T_co], /) -> bool: ... + def isdisjoint(self, s: Iterable[object], /) -> bool: ... def issubset(self, s: Iterable[object], /) -> bool: ... def issuperset(self, s: Iterable[object], /) -> bool: ... - def symmetric_difference(self, s: Iterable[_T_co], /) -> frozenset[_T_co]: ... + def symmetric_difference(self, s: Iterable[_S], /) -> frozenset[_T_co | _S]: ... def union(self, *s: Iterable[_S]) -> frozenset[_T_co | _S]: ... def __len__(self) -> int: ... def __contains__(self, o: object, /) -> bool: ... def __iter__(self) -> Iterator[_T_co]: ... - def __and__(self, value: AbstractSet[_T_co], /) -> frozenset[_T_co]: ... + def __and__(self, value: AbstractSet[object], /) -> frozenset[_T_co]: ... def __or__(self, value: AbstractSet[_S], /) -> frozenset[_T_co | _S]: ... def __sub__(self, value: AbstractSet[object], /) -> frozenset[_T_co]: ... def __xor__(self, value: AbstractSet[_S], /) -> frozenset[_T_co | _S]: ... diff --git a/stdlib/typing.pyi b/stdlib/typing.pyi index 7ff6a5b27887..af1d1650da41 100644 --- a/stdlib/typing.pyi +++ b/stdlib/typing.pyi @@ -721,7 +721,7 @@ class MutableSet(AbstractSet[_T]): @abstractmethod def add(self, value: _T, /) -> None: ... @abstractmethod - def discard(self, value: Any, /) -> None: ... + def discard(self, value: _T, /) -> None: ... # Mixin methods def clear(self) -> None: ... def pop(self) -> _T: ...