Skip to content

Commit

Permalink
Added support for NoneType to be used in an isinstance type guard…
Browse files Browse the repository at this point in the history
… or match statement. This addresses #4402.
  • Loading branch information
msfterictraut committed Jan 4, 2023
1 parent 5447798 commit ca078ab
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 5 deletions.
4 changes: 2 additions & 2 deletions packages/pyright-internal/src/analyzer/patternMatching.ts
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ function narrowTypeBasedOnClassPattern(
type,
/* conditionFilter */ undefined,
(subjectSubtypeExpanded, subjectSubtypeUnexpanded) => {
if (!isClassInstance(subjectSubtypeExpanded)) {
if (!isNoneInstance(subjectSubtypeExpanded) && !isClassInstance(subjectSubtypeExpanded)) {
return subjectSubtypeUnexpanded;
}

Expand All @@ -521,7 +521,7 @@ function narrowTypeBasedOnClassPattern(
// if the types match exactly or the subtype is a final class and
// therefore cannot be subclassed.
if (!evaluator.assignType(subjectSubtypeExpanded, classInstance)) {
if (!ClassType.isFinal(subjectSubtypeExpanded)) {
if (isClass(subjectSubtypeExpanded) && !ClassType.isFinal(subjectSubtypeExpanded)) {
return subjectSubtypeExpanded;
}
}
Expand Down
8 changes: 7 additions & 1 deletion packages/pyright-internal/src/analyzer/typeGuards.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1366,7 +1366,13 @@ function narrowTypeForIsInstance(

if (isInstanceCheck) {
if (isNoneInstance(subtype)) {
const containsNoneType = classTypeList.some((t) => isNoneTypeClass(t));
const containsNoneType = classTypeList.some((t) => {
if (isNoneTypeClass(t)) {
return true;
}
return isInstantiableClass(t) && ClassType.isBuiltIn(t, 'NoneType');
});

if (isPositiveTest) {
return containsNoneType ? subtype : undefined;
} else {
Expand Down
5 changes: 5 additions & 0 deletions packages/pyright-internal/src/analyzer/typeUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1941,6 +1941,11 @@ export function convertToInstance(type: Type, includeSubclasses = true): Type {
}
}

// Handle NoneType as a special case.
if (TypeBase.isInstantiable(subtype) && ClassType.isBuiltIn(subtype, 'NoneType')) {
return NoneType.createInstance();
}

return ClassType.cloneAsInstance(subtype, includeSubclasses);
}

Expand Down
18 changes: 16 additions & 2 deletions packages/pyright-internal/src/tests/samples/match10.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
# This sample tests the reportMatchNotExhaustive diagnostic check.

from types import NoneType
from typing import Literal
from enum import Enum


def func1(subj: Literal["a", "b"], cond: bool):
# This should generate an error if reportMatchNotExhaustive is enabled.
match subj:
case "a":
pass

case "b" if cond:
pass

Expand All @@ -19,11 +21,13 @@ def func2(subj: object):
case int():
pass


def func3(subj: object):
match subj:
case object():
pass


def func4(subj: tuple[str] | tuple[int]):
match subj[0]:
case str():
Expand All @@ -32,15 +36,17 @@ def func4(subj: tuple[str] | tuple[int]):
case int():
pass


def func5(subj: Literal[1, 2, 3]):
# This should generate an error if reportMatchNotExhaustive is enabled.
match subj:
case 1 | 2:
pass


class Color(Enum):
red = 0
green= 1
green = 1
blue = 2


Expand Down Expand Up @@ -70,8 +76,16 @@ def func7() -> int:
class SingleColor(Enum):
red = 0


def func8(x: SingleColor) -> int:
match x:
case SingleColor.red:
return 1


def func9(x: int | None):
match x:
case NoneType():
return 1
case int():
return 2
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# This sample exercises the type analyzer's isinstance type narrowing logic.

from types import NoneType
from typing import List, Optional, Sized, Type, TypeVar, Union, Any


Expand Down Expand Up @@ -160,3 +161,10 @@ def func8(a: int | list[int] | dict[str, int] | None):
reveal_type(a, expected_text="int | list[int] | None")
else:
reveal_type(a, expected_text="dict[str, int]")


def func9(a: int | None):
if not isinstance(a, NoneType):
reveal_type(a, expected_text="int")
else:
reveal_type(a, expected_text="None")

0 comments on commit ca078ab

Please sign in to comment.