Skip to content

Commit

Permalink
Improved handling of Annotated and other special forms when they ar…
Browse files Browse the repository at this point in the history
…e used in runtime value expressions rather than annotations. This addresses microsoft#7049. (microsoft#7050)
  • Loading branch information
erictraut authored Jan 19, 2024
1 parent 01b1aa0 commit 0fc5682
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 49 deletions.
20 changes: 14 additions & 6 deletions packages/pyright-internal/src/analyzer/checker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -910,6 +910,7 @@ export class Checker extends ParseTreeWalker {

override visitReturn(node: ReturnNode): boolean {
let returnTypeResult: TypeResult;
let returnType: Type | undefined;

const enclosingFunctionNode = ParseTreeUtils.getEnclosingFunction(node);
const declaredReturnType = enclosingFunctionNode
Expand All @@ -923,6 +924,13 @@ export class Checker extends ParseTreeWalker {
returnTypeResult = { type: this._evaluator.getNoneType() };
}

returnType = returnTypeResult.type;

// If this type is a special form, use the special form instead.
if (returnType.specialForm) {
returnType = returnType.specialForm;
}

// If the enclosing function is async and a generator, the return
// statement is not allowed to have an argument. A syntax error occurs
// at runtime in this case.
Expand Down Expand Up @@ -952,7 +960,7 @@ export class Checker extends ParseTreeWalker {
if (
this._evaluator.assignType(
declaredReturnType,
returnTypeResult.type,
returnType,
diagAddendum,
new TypeVarContext(),
/* srcTypeVarContext */ undefined,
Expand Down Expand Up @@ -987,7 +995,7 @@ export class Checker extends ParseTreeWalker {
if (
this._evaluator.assignType(
adjustedReturnType,
returnTypeResult.type,
returnType,
diagAddendum,
/* destTypeVarContext */ undefined,
/* srcTypeVarContext */ undefined,
Expand All @@ -1010,7 +1018,7 @@ export class Checker extends ParseTreeWalker {
this._evaluator.addDiagnostic(
DiagnosticRule.reportGeneralTypeIssues,
LocMessage.returnTypeMismatch().format({
exprType: this._evaluator.printType(returnTypeResult.type),
exprType: this._evaluator.printType(returnType),
returnType: this._evaluator.printType(declaredReturnType),
}) + diagAddendum.getString(),
node.returnExpression ?? node,
Expand All @@ -1020,17 +1028,17 @@ export class Checker extends ParseTreeWalker {
}
}

if (isUnknown(returnTypeResult.type)) {
if (isUnknown(returnType)) {
this._evaluator.addDiagnostic(
DiagnosticRule.reportUnknownVariableType,
LocMessage.returnTypeUnknown(),
node.returnExpression ?? node
);
} else if (isPartlyUnknown(returnTypeResult.type)) {
} else if (isPartlyUnknown(returnType)) {
this._evaluator.addDiagnostic(
DiagnosticRule.reportUnknownVariableType,
LocMessage.returnTypePartiallyUnknown().format({
returnType: this._evaluator.printType(returnTypeResult.type, { expandTypeAlias: true }),
returnType: this._evaluator.printType(returnType, { expandTypeAlias: true }),
}),
node.returnExpression ?? node
);
Expand Down
23 changes: 16 additions & 7 deletions packages/pyright-internal/src/analyzer/typeEvaluator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14675,8 +14675,8 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
// We'll abuse our internal types a bit by specializing it with
// a type argument anyway.
function createTypeGuardType(
errorNode: ParseNode,
classType: ClassType,
errorNode: ParseNode,
typeArgs: TypeResultWithNode[] | undefined,
flags: EvaluatorFlags
): Type {
Expand Down Expand Up @@ -14939,8 +14939,8 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
}

function createConcatenateType(
errorNode: ParseNode,
classType: ClassType,
errorNode: ParseNode,
typeArgs: TypeResultWithNode[] | undefined,
flags: EvaluatorFlags
): Type {
Expand Down Expand Up @@ -14972,7 +14972,11 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
return createSpecialType(classType, typeArgs, /* paramLimit */ undefined, /* allowParamSpec */ true);
}

function createAnnotatedType(errorNode: ParseNode, typeArgs: TypeResultWithNode[] | undefined): TypeResult {
function createAnnotatedType(
classType: ClassType,
errorNode: ParseNode,
typeArgs: TypeResultWithNode[] | undefined
): TypeResult {
if (typeArgs && typeArgs.length < 2) {
addError(LocMessage.annotatedTypeArgMissing(), errorNode);
}
Expand All @@ -14986,7 +14990,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
}

return {
type: TypeBase.cloneForAnnotated(typeArgs[0].type),
type: TypeBase.cloneAsSpecialForm(typeArgs[0].type, classType),
isReadOnly: typeArgs[0].isReadOnly,
isRequired: typeArgs[0].isRequired,
isNotRequired: typeArgs[0].isNotRequired,
Expand Down Expand Up @@ -18003,6 +18007,11 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions

let returnType = returnTypeResult.type;

// If the type is a special form, use the special form instead.
if (returnType.specialForm) {
returnType = returnType.specialForm;
}

// If the return type includes an instance of a class with isEmptyContainer
// set, clear that because we don't want this flag to "leak" into the
// inferred return type.
Expand Down Expand Up @@ -19442,15 +19451,15 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
}

case 'Annotated': {
return createAnnotatedType(errorNode, typeArgs);
return createAnnotatedType(classType, errorNode, typeArgs);
}

case 'Concatenate': {
return { type: createConcatenateType(errorNode, classType, typeArgs, flags) };
return { type: createConcatenateType(classType, errorNode, typeArgs, flags) };
}

case 'TypeGuard': {
return { type: createTypeGuardType(errorNode, classType, typeArgs, flags) };
return { type: createTypeGuardType(classType, errorNode, typeArgs, flags) };
}

case 'Unpack': {
Expand Down
7 changes: 1 addition & 6 deletions packages/pyright-internal/src/analyzer/typePrinter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ import {
Type,
TypeBase,
TypeCategory,
TypeFlags,
TypeVarType,
Variance,
} from './types';
Expand Down Expand Up @@ -539,11 +538,7 @@ function printTypeInternal(
const sourceSubtypeInstance = convertToInstance(sourceSubtype);

for (const unionSubtype of type.subtypes) {
if (
isTypeSame(sourceSubtypeInstance, unionSubtype, {
typeFlagsToHonor: TypeFlags.Instance | TypeFlags.Instantiable,
})
) {
if (isTypeSame(sourceSubtypeInstance, unionSubtype)) {
if (!subtypeHandledSet.has(unionSubtypeIndex)) {
allSubtypesPreviouslyHandled = false;
}
Expand Down
30 changes: 1 addition & 29 deletions packages/pyright-internal/src/analyzer/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,6 @@ export const enum TypeFlags {

// This type refers to something that has been instantiated.
Instance = 1 << 1,

// This type refers to a type that is wrapped an "Annotated"
// (PEP 593) annotation.
Annotated = 1 << 2,
}

export type UnionableType =
Expand Down Expand Up @@ -113,7 +109,6 @@ export interface TypeSameOptions {
ignorePseudoGeneric?: boolean;
ignoreTypeFlags?: boolean;
ignoreConditions?: boolean;
typeFlagsToHonor?: TypeFlags;
ignoreTypedDictNarrowEntries?: boolean;
treatAnySameAsUnknown?: boolean;
}
Expand Down Expand Up @@ -194,10 +189,6 @@ export namespace TypeBase {
return (type.flags & TypeFlags.Instance) !== 0;
}

export function isAnnotated(type: TypeBase) {
return (type.flags & TypeFlags.Annotated) !== 0;
}

export function isAmbiguous(type: TypeBase) {
return !!type.isAmbiguous;
}
Expand Down Expand Up @@ -294,12 +285,6 @@ export namespace TypeBase {
return typeClone;
}

export function cloneForAnnotated(type: Type) {
const typeClone = cloneType(type);
typeClone.flags |= TypeFlags.Annotated;
return typeClone;
}

export function cloneForCondition<T extends Type>(type: T, condition: TypeCondition[] | undefined): T {
// Handle the common case where there are no conditions. In this case,
// cloning isn't necessary.
Expand Down Expand Up @@ -2868,20 +2853,7 @@ export function isTypeSame(type1: Type, type2: Type, options: TypeSameOptions =
}

if (!options.ignoreTypeFlags) {
let type1Flags = type1.flags;
let type2Flags = type2.flags;

// Mask out the flags that we don't care about.
if (options.typeFlagsToHonor !== undefined) {
type1Flags &= options.typeFlagsToHonor;
type2Flags &= options.typeFlagsToHonor;
} else {
// By default, we don't care about the Annotated flag.
type1Flags &= ~TypeFlags.Annotated;
type2Flags &= ~TypeFlags.Annotated;
}

if (type1Flags !== type2Flags) {
if (type1.flags !== type2.flags) {
return false;
}
}
Expand Down
7 changes: 7 additions & 0 deletions packages/pyright-internal/src/tests/samples/annotated1.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,10 @@ async def func3():
x6: Annotated[str, "a" "b" "c"]
x7: Annotated[str, "a\nb"]
x8: Annotated[str, *(1, 2, 3)]


def func4():
return Annotated[int, 2 + 2]


reveal_type(func4(), expected_text="type[Annotated]")
2 changes: 1 addition & 1 deletion packages/pyright-internal/src/tests/typeEvaluator4.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1217,7 +1217,7 @@ test('Annotated1', () => {

configOptions.defaultPythonVersion = PythonVersion.V3_8;
const analysisResults38 = TestUtils.typeAnalyzeSampleFiles(['annotated1.py'], configOptions);
TestUtils.validateResults(analysisResults38, 5);
TestUtils.validateResults(analysisResults38, 6);

configOptions.defaultPythonVersion = PythonVersion.V3_11;
const analysisResults39 = TestUtils.typeAnalyzeSampleFiles(['annotated1.py'], configOptions);
Expand Down

0 comments on commit 0fc5682

Please sign in to comment.