diff --git a/pyrefly/lib/alt/narrow.rs b/pyrefly/lib/alt/narrow.rs index 285ddbc733..650799717c 100644 --- a/pyrefly/lib/alt/narrow.rs +++ b/pyrefly/lib/alt/narrow.rs @@ -398,6 +398,11 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { for right in self.as_class_info(right.clone()) { res.push(self.distribute_over_union(left, |l| { if let Some((tparams, right)) = self.unwrap_isinstance_target(l, &right) { + if matches!(&right, Type::ClassType(cls) if cls.is_builtin("type")) + && matches!(l, Type::Type(_) | Type::ClassDef(_)) + { + return l.clone(); + } let (vs, right) = self .solver() .fresh_quantified(&tparams, right, self.uniques); @@ -627,14 +632,24 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { let mut res = Vec::new(); for (right, allows_negative_narrow) in self.expr_as_class_info(right_expr, errors) { if allows_negative_narrow - && let Some(left_untyped) = - self.untype_opt(left.clone(), right_expr.range(), errors) && let Some((tparams, right)) = self.unwrap_class_object_silently(&right) { let (vs, right) = self .solver() .fresh_quantified(&tparams, right, self.uniques); - res.push(self.issubclass_result(self.subtract(&left_untyped, &right), left)); + res.push(self.distribute_over_union(left, |left| { + let Some(left_untyped) = + self.untype_opt(left.clone(), right_expr.range(), errors) + else { + return left.clone(); + }; + let instance_result = self.subtract(&left_untyped, &right); + if instance_result.is_never() { + instance_result + } else { + self.issubclass_result(instance_result, left) + } + })); // These are safe to ignore, as the only possible specialization errors are handled elsewhere: // * If `left` is an invalid specialization, the error has already been reported at its definition site. // * Unsafe runtime protocol overlaps are separately checked for in special_calls.rs. diff --git a/pyrefly/lib/test/narrow.rs b/pyrefly/lib/test/narrow.rs index a787b97a4e..25278d95fd 100644 --- a/pyrefly/lib/test/narrow.rs +++ b/pyrefly/lib/test/narrow.rs @@ -1392,6 +1392,22 @@ def test_isinstance_then_issubclass(x: object) -> None: "#, ); +testcase!( + test_isinstance_type_and_issubclass_else, + r#" +from typing import Iterable, assert_type +import enum + +def main(categories: Iterable[str] | type[enum.Enum]) -> None: + cached_categories: tuple[str, ...] + if isinstance(categories, type) and issubclass(categories, enum.Enum): + cached_categories = tuple(member.value for member in categories) + else: + assert_type(categories, Iterable[str]) + cached_categories = tuple(categories) + "#, +); + testcase!( test_issubclass_with_metaclass_instance, r#"