Skip to content

Commit

Permalink
Added type enforcement for the _value_ type in an Enum class. Also …
Browse files Browse the repository at this point in the history
…added enforcement for custom `__new__` and `__init__` method signatures. This addresses microsoft#7030 and microsoft#7029. (microsoft#7044)
  • Loading branch information
erictraut authored Jan 19, 2024
1 parent 85e8de6 commit 01b1aa0
Show file tree
Hide file tree
Showing 9 changed files with 266 additions and 68 deletions.
114 changes: 100 additions & 14 deletions packages/pyright-internal/src/analyzer/checker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ import { getBoundCallMethod, getBoundInitMethod, getBoundNewMethod } from './con
import { Declaration, DeclarationType, isAliasDeclaration } from './declaration';
import { getNameNodeForDeclaration } from './declarationUtils';
import { deprecatedAliases, deprecatedSpecialForms } from './deprecatedSymbols';
import { isEnumClassWithMembers } from './enums';
import { getEnumDeclaredValueType, isEnumClassWithMembers } from './enums';
import { ImportResolver, ImportedModuleDescriptor, createImportedModuleDescriptor } from './importResolver';
import { ImportResult, ImportType } from './importResult';
import { getRelativeModuleName, getTopLevelImports } from './importStatementUtils';
Expand All @@ -117,7 +117,13 @@ import { Symbol } from './symbol';
import * as SymbolNameUtils from './symbolNameUtils';
import { getLastTypedDeclaredForSymbol } from './symbolUtils';
import { maxCodeComplexity } from './typeEvaluator';
import { FunctionTypeResult, MemberAccessDeprecationInfo, TypeEvaluator, TypeResult } from './typeEvaluatorTypes';
import {
FunctionArgument,
FunctionTypeResult,
MemberAccessDeprecationInfo,
TypeEvaluator,
TypeResult,
} from './typeEvaluatorTypes';
import {
getElementTypeForContainerNarrowing,
isIsinstanceFilterSubclass,
Expand Down Expand Up @@ -158,6 +164,7 @@ import {
AnyType,
ClassType,
ClassTypeFlags,
EnumLiteral,
FunctionType,
FunctionTypeFlags,
OverloadedFunctionType,
Expand Down Expand Up @@ -364,7 +371,7 @@ export class Checker extends ParseTreeWalker {

this._validateDataClassPostInit(classTypeResult.classType, node);

this._reportDuplicateEnumMembers(classTypeResult.classType);
this._validateEnumMembers(classTypeResult.classType, node);

if (ClassType.isTypedDictClass(classTypeResult.classType)) {
this._validateTypedDictClassSuite(node.suite);
Expand Down Expand Up @@ -4743,32 +4750,111 @@ export class Checker extends ParseTreeWalker {
});
}

private _reportDuplicateEnumMembers(classType: ClassType) {
// Validates that the values associated with enum members are type compatible.
// Also looks for duplicate values.
private _validateEnumMembers(classType: ClassType, node: ClassNode) {
if (!ClassType.isEnumClass(classType) || ClassType.isBuiltIn(classType)) {
return;
}

// Does the "_value_" field have a declared type? If so, we'll enforce it.
const declaredValueType = getEnumDeclaredValueType(this._evaluator, classType, /* declaredTypesOnly */ true);

// Is there a custom "__new__" and/or "__init__" method? If so, we'll
// verify that the signature of these calls is compatible with the values.
const newMemberTypeResult = getBoundNewMethod(
this._evaluator,
node.name,
classType,
MemberAccessFlags.SkipBaseClasses
);
const initMemberTypeResult = getBoundInitMethod(
this._evaluator,
node.name,
ClassType.cloneAsInstance(classType),
MemberAccessFlags.SkipBaseClasses
);

classType.details.fields.forEach((symbol, name) => {
// Enum members don't have type annotations.
if (symbol.getTypedDeclarations().length > 0) {
return;
}

const symbolType = this._evaluator.getEffectiveTypeOfSymbol(symbol);
// Is this symbol a literal instance of the enum class?
if (
!isClassInstance(symbolType) ||
!ClassType.isSameGenericClass(symbolType, classType) ||
!(symbolType.literalValue instanceof EnumLiteral)
) {
return;
}

// Look for a duplicate assignment.
const decls = symbol.getDeclarations();
if (decls.length >= 2 && decls[0].type === DeclarationType.Variable) {
const symbolType = this._evaluator.getEffectiveTypeOfSymbol(symbol);
this._evaluator.addDiagnostic(
DiagnosticRule.reportGeneralTypeIssues,
LocMessage.duplicateEnumMember().format({ name }),
decls[1].node
);

return;
}

// Is this symbol a literal instance of the enum class?
if (decls[0].type !== DeclarationType.Variable) {
return;
}

const declNode = decls[0].node;
const assignedValueType = symbolType.literalValue.itemType;
const assignmentNode = ParseTreeUtils.getParentNodeOfType<AssignmentNode>(
declNode,
ParseNodeType.Assignment
);
const errorNode = assignmentNode?.rightExpression ?? declNode;

// Validate the __new__ and __init__ methods if present.
if (newMemberTypeResult || initMemberTypeResult) {
if (!isAnyOrUnknown(assignedValueType)) {
// Construct an argument list. If the assigned type is a tuple, we'll
// unpack it. Otherwise, only one argument is passed.
const argList: FunctionArgument[] = [
{
argumentCategory:
isClassInstance(assignedValueType) && isTupleClass(assignedValueType)
? ArgumentCategory.UnpackedList
: ArgumentCategory.Simple,
typeResult: { type: assignedValueType },
},
];

if (newMemberTypeResult) {
this._evaluator.validateCallArguments(errorNode, argList, newMemberTypeResult);
}

if (initMemberTypeResult) {
this._evaluator.validateCallArguments(errorNode, argList, initMemberTypeResult);
}
}
} else if (declaredValueType) {
const diag = new DiagnosticAddendum();

// If the assigned value is already an instance of this enum class, skip this check.
if (
isClassInstance(symbolType) &&
ClassType.isSameGenericClass(symbolType, classType) &&
symbolType.literalValue !== undefined
!isClassInstance(assignedValueType) ||
!ClassType.isSameGenericClass(assignedValueType, classType)
) {
this._evaluator.addDiagnostic(
DiagnosticRule.reportGeneralTypeIssues,
LocMessage.duplicateEnumMember().format({ name }),
decls[1].node
);
if (!this._evaluator.assignType(declaredValueType, assignedValueType, diag)) {
this._evaluator.addDiagnostic(
DiagnosticRule.reportGeneralTypeIssues,
LocMessage.typeAssignmentMismatch().format(
this._evaluator.printSrcDestTypes(assignedValueType, declaredValueType)
) + diag.getString(),
errorNode
);
}
}
}
});
Expand Down
18 changes: 7 additions & 11 deletions packages/pyright-internal/src/analyzer/constructors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,13 @@ export function getBoundNewMethod(
evaluator: TypeEvaluator,
errorNode: ExpressionNode,
type: ClassType,
skipObjectBase = true
additionalFlags = MemberAccessFlags.SkipObjectBaseClass
) {
let flags =
const flags =
MemberAccessFlags.SkipClassMembers |
MemberAccessFlags.SkipAttributeAccessOverride |
MemberAccessFlags.TreatConstructorAsClassMethod;
if (skipObjectBase) {
flags |= MemberAccessFlags.SkipObjectBaseClass;
}
MemberAccessFlags.TreatConstructorAsClassMethod |
additionalFlags;

return evaluator.getTypeOfBoundMember(errorNode, type, '__new__', { method: 'get' }, /* diag */ undefined, flags);
}
Expand All @@ -79,12 +77,10 @@ export function getBoundInitMethod(
evaluator: TypeEvaluator,
errorNode: ExpressionNode,
type: ClassType,
skipObjectBase = true
additionalFlags = MemberAccessFlags.SkipObjectBaseClass
) {
let flags = MemberAccessFlags.SkipInstanceMembers | MemberAccessFlags.SkipAttributeAccessOverride;
if (skipObjectBase) {
flags |= MemberAccessFlags.SkipObjectBaseClass;
}
const flags =
MemberAccessFlags.SkipInstanceMembers | MemberAccessFlags.SkipAttributeAccessOverride | additionalFlags;

return evaluator.getTypeOfBoundMember(errorNode, type, '__init__', { method: 'get' }, /* diag */ undefined, flags);
}
Expand Down
92 changes: 63 additions & 29 deletions packages/pyright-internal/src/analyzer/enums.ts
Original file line number Diff line number Diff line change
Expand Up @@ -361,30 +361,18 @@ export function transformTypeForPossibleEnumClass(

let valueType: Type;

// If the class includes a __new__ method, we cannot assume that
// the value of each enum element is simply the value assigned to it.
// The __new__ method can transform the value in ways that we cannot
// determine statically.
const newMember = lookUpClassMember(enumClassInfo.classType, '__new__', MemberAccessFlags.SkipBaseClasses);
if (newMember) {
// We may want to change this to UnknownType in the future, but
// for now, we'll leave it as Any which is consistent with the
// type specified in the Enum class definition in enum.pyi.
valueType = AnyType.create();
} else {
valueType = getValueType();

// If the LHS is an unpacked tuple, we need to handle this as
// a special case.
if (isUnpackedTuple) {
valueType =
evaluator.getTypeOfIterator(
{ type: valueType },
/* isAsync */ false,
nameNode,
/* emitNotIterableError */ false
)?.type ?? UnknownType.create();
}
valueType = getValueType();

// If the LHS is an unpacked tuple, we need to handle this as
// a special case.
if (isUnpackedTuple) {
valueType =
evaluator.getTypeOfIterator(
{ type: valueType },
/* isAsync */ false,
nameNode,
/* emitNotIterableError */ false
)?.type ?? UnknownType.create();
}

// The spec excludes descriptors.
Expand Down Expand Up @@ -425,6 +413,7 @@ export function transformTypeForPossibleEnumClass(
nameNode.value,
valueType
);

return ClassType.cloneAsInstance(ClassType.cloneWithLiteral(enumClassInfo.classType, enumLiteral));
}

Expand All @@ -442,6 +431,34 @@ export function isDeclInEnumClass(evaluator: TypeEvaluator, decl: VariableDeclar
return ClassType.isEnumClass(classInfo.classType);
}

export function getEnumDeclaredValueType(
evaluator: TypeEvaluator,
classType: ClassType,
declaredTypesOnly = false
): Type | undefined {
// See if there is a declared type for "_value_".
let valueType: Type | undefined;

const declaredValueMember = lookUpClassMember(
classType,
'_value_',
declaredTypesOnly ? MemberAccessFlags.DeclaredTypesOnly : MemberAccessFlags.Default
);

// If the declared type comes from the 'Enum' base class, ignore it
// because it will be "Any", which isn't useful to us here.
if (
declaredValueMember &&
declaredValueMember.classType &&
isClass(declaredValueMember.classType) &&
!ClassType.isBuiltIn(declaredValueMember.classType, 'Enum')
) {
valueType = evaluator.getTypeOfMember(declaredValueMember);
}

return valueType;
}

export function getTypeOfEnumMember(
evaluator: TypeEvaluator,
errorNode: ParseNode,
Expand Down Expand Up @@ -488,25 +505,42 @@ export function getTypeOfEnumMember(
}
}

// See if there is a declared type for "_value_".
const valueType = getEnumDeclaredValueType(evaluator, classType);

if (memberName === 'value' || memberName === '_value_') {
// If the enum class has a custom metaclass, it may implement some
// "magic" that computes different values for the "value" attribute.
// "magic" that computes different values for the "_value_" attribute.
// This occurs, for example, in the django TextChoices class. If we
// detect a custom metaclass, we'll assume the value is Any.
// detect a custom metaclass, we'll use the declared type of _value_
// if it is declared.
const metaclass = classType.details.effectiveMetaclass;
if (metaclass && isClass(metaclass) && !ClassType.isBuiltIn(metaclass)) {
return { type: AnyType.create(), isIncomplete };
return { type: valueType ?? AnyType.create(), isIncomplete };
}

// If the enum class has a custom __new__ or __init__ method,
// it may implement some magic that computes different values for
// the "_value_" attribute. If we see a customer __new__ or __init__,
// we'll assume the value type is what we computed above, or Any.
const newMember = lookUpClassMember(classType, '__new__', MemberAccessFlags.SkipBaseClasses);
const initMember = lookUpClassMember(classType, '__init__', MemberAccessFlags.SkipBaseClasses);
if (newMember || initMember) {
return { type: valueType ?? AnyType.create(), isIncomplete };
}

// There were no explicit assignments to the "_value_" attribute, so we can
// assume that the values are assigned directly to the "_value_" by
// the EnumMeta metaclass.
if (literalValue) {
assert(literalValue instanceof EnumLiteral);

// If there is no known value type for this literal value,
// return undefined. This will cause the caller to fall back
// on the definition of `value` within the class definition
// on the definition of "_value_" within the class definition
// (if present).
if (isAny(literalValue.itemType)) {
return undefined;
return valueType ? { type: valueType, isIncomplete } : undefined;
}

return { type: literalValue.itemType, isIncomplete };
Expand Down
4 changes: 2 additions & 2 deletions packages/pyright-internal/src/analyzer/parseTreeUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1086,12 +1086,12 @@ export function isNodeContainedWithin(node: ParseNode, potentialContainer: Parse
return false;
}

export function getParentNodeOfType(node: ParseNode, containerType: ParseNodeType): ParseNode | undefined {
export function getParentNodeOfType<T extends ParseNode>(node: ParseNode, containerType: ParseNodeType): T | undefined {
let curNode: ParseNode | undefined = node;

while (curNode) {
if (curNode.nodeType === containerType) {
return curNode;
return curNode as T;
}

curNode = curNode.parent;
Expand Down
Loading

0 comments on commit 01b1aa0

Please sign in to comment.