Skip to content

Commit

Permalink
Merge pull request #737 from koxudaxi/support_pydantic_v2_validators_…
Browse files Browse the repository at this point in the history
…for_232

Support pydantic v2 validators for 232
  • Loading branch information
koxudaxi authored Jul 7, 2023
2 parents ca74b1e + c0365c8 commit 6918ab2
Show file tree
Hide file tree
Showing 15 changed files with 139 additions and 45 deletions.
38 changes: 34 additions & 4 deletions src/com/koxudaxi/pydantic/Pydantic.kt
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ import com.jetbrains.extensions.QNameResolveContext
import com.jetbrains.extensions.resolveToElement
import com.jetbrains.python.PyNames
import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider
import com.jetbrains.python.packaging.PyPackageManagers
import com.jetbrains.python.psi.*
import com.jetbrains.python.psi.impl.PyStarArgumentImpl
import com.jetbrains.python.psi.impl.PyTargetExpressionImpl
import com.jetbrains.python.psi.resolve.PyResolveContext
import com.jetbrains.python.psi.resolve.PyResolveUtil
Expand All @@ -33,6 +33,11 @@ const val VALIDATOR_Q_NAME = "pydantic.class_validators.validator"
const val VALIDATOR_SHORT_Q_NAME = "pydantic.validator"
const val ROOT_VALIDATOR_Q_NAME = "pydantic.class_validators.root_validator"
const val ROOT_VALIDATOR_SHORT_Q_NAME = "pydantic.root_validator"
const val FIELD_VALIDATOR_Q_NAME = "pydantic.field_validator"
const val FIELD_VALIDATOR_SHORT_Q_NAME = "pydantic.functional_validators.field_validator"
const val MODEL_VALIDATOR_Q_NAME = "pydantic.model_validator"
const val MODEL_VALIDATOR_SHORT_Q_NAME = "pydantic.functional_validators.model_validator"

const val SCHEMA_Q_NAME = "pydantic.schema.Schema"
const val FIELD_Q_NAME = "pydantic.fields.Field"
const val DATACLASS_FIELD_Q_NAME = "dataclasses.field"
Expand Down Expand Up @@ -85,6 +90,14 @@ val ROOT_VALIDATOR_QUALIFIED_NAME = QualifiedName.fromDottedString(ROOT_VALIDATO

val ROOT_VALIDATOR_SHORT_QUALIFIED_NAME = QualifiedName.fromDottedString(ROOT_VALIDATOR_SHORT_Q_NAME)

val FIELD_VALIDATOR_QUALIFIED_NAME = QualifiedName.fromDottedString(FIELD_VALIDATOR_Q_NAME)

val FIELD_VALIDATOR_SHORT_QUALIFIED_NAME = QualifiedName.fromDottedString(FIELD_VALIDATOR_SHORT_Q_NAME)

val MODEL_VALIDATOR_QUALIFIED_NAME = QualifiedName.fromDottedString(MODEL_VALIDATOR_Q_NAME)

val MODEL_VALIDATOR_SHORT_QUALIFIED_NAME = QualifiedName.fromDottedString(MODEL_VALIDATOR_SHORT_Q_NAME)

val DATA_CLASS_QUALIFIED_NAME = QualifiedName.fromDottedString(DATA_CLASS_Q_NAME)

val DATA_CLASS_SHORT_QUALIFIED_NAME = QualifiedName.fromDottedString(DATA_CLASS_SHORT_Q_NAME)
Expand All @@ -103,6 +116,17 @@ val VALIDATOR_QUALIFIED_NAMES = listOf(
ROOT_VALIDATOR_SHORT_QUALIFIED_NAME
)

val V2_VALIDATOR_QUALIFIED_NAMES = listOf(
VALIDATOR_QUALIFIED_NAME,
VALIDATOR_SHORT_QUALIFIED_NAME,
ROOT_VALIDATOR_QUALIFIED_NAME,
ROOT_VALIDATOR_SHORT_QUALIFIED_NAME,
FIELD_VALIDATOR_QUALIFIED_NAME,
FIELD_VALIDATOR_SHORT_QUALIFIED_NAME,
MODEL_VALIDATOR_QUALIFIED_NAME,
MODEL_VALIDATOR_SHORT_QUALIFIED_NAME
)

val VERSION_SPLIT_PATTERN: Pattern = Pattern.compile("[.a-zA-Z]")!!

val pydanticVersionCache: HashMap<String, KotlinVersion> = hashMapOf()
Expand Down Expand Up @@ -210,7 +234,9 @@ internal fun isDataclassMissing(pyTargetExpression: PyTargetExpression): Boolean
return pyTargetExpression.qualifiedName == DATACLASS_MISSING
}

internal val PyFunction.isValidatorMethod: Boolean get() = hasDecorator(this, VALIDATOR_QUALIFIED_NAMES)
internal fun PyFunction.isValidatorMethod(pydanticVersion: KotlinVersion?): Boolean =
hasDecorator(this, if(pydanticVersion.isV2) V2_VALIDATOR_QUALIFIED_NAMES else VALIDATOR_QUALIFIED_NAMES)



internal val PyClass.isConfigClass: Boolean get() = name == "Config"
Expand Down Expand Up @@ -406,7 +432,7 @@ fun getConfig(
pydanticVersion: KotlinVersion? = null,
): HashMap<String, Any?> {
val config = hashMapOf<String, Any?>()
val version = pydanticVersion ?: PydanticCacheService.getVersion(pyClass.project, context)
val version = pydanticVersion ?: PydanticCacheService.getVersion(pyClass.project)
getAncestorPydanticModels(pyClass, false, context)
.reversed()
.map { getConfig(it, context, false, version) }
Expand Down Expand Up @@ -661,4 +687,8 @@ fun PyCallableType.getPydanticModel(includeDataclass: Boolean, context: TypeEval


val KotlinVersion?.isV2: Boolean
get() = this?.isAtLeast(2, 0) == true
get() = this?.isAtLeast(2, 0) == true

val Sdk.pydanticVersion: String?
get() = PyPackageManagers.getInstance()
.forSdk(this).packages?.find { it.name == "pydantic" }?.version
36 changes: 19 additions & 17 deletions src/com/koxudaxi/pydantic/PydanticCacheService.kt
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
package com.koxudaxi.pydantic

import com.intellij.openapi.project.Project
import com.jetbrains.python.psi.PyStringLiteralExpression
import com.jetbrains.python.psi.PyTargetExpression
import com.jetbrains.python.psi.impl.PyStringLiteralExpressionImpl
import com.jetbrains.python.psi.types.TypeEvalContext
import com.jetbrains.python.sdk.pythonSdk

class PydanticCacheService(val project: Project) {
private var version: KotlinVersion? = null
Expand All @@ -17,34 +15,35 @@ class PydanticCacheService(val project: Project) {
.filterNot { it.startsWith("__") && it.endsWith("__") }
.toSet()
}
private fun getVersion(context: TypeEvalContext): KotlinVersion? {
val version = getPsiElementByQualifiedName(VERSION_QUALIFIED_NAME, project, context) as? PyTargetExpression
?: return null
val versionString =
(version.findAssignedValue()?.lastChild?.firstChild?.nextSibling as? PyStringLiteralExpression)?.stringValue
?: (version.findAssignedValue() as? PyStringLiteralExpressionImpl)?.stringValue ?: return null
return setVersion(versionString)
private fun getVersion(): KotlinVersion? {
val sdk = project.pythonSdk ?: return null
val versionString = sdk.pydanticVersion ?: return null
return getOrPutVersionFromVersionCache(versionString)
}

private fun setVersion(version: String): KotlinVersion {
private fun getOrPutVersionFromVersionCache(version: String): KotlinVersion? {
return pydanticVersionCache.getOrPut(version) {
val versionList = version.split(VERSION_SPLIT_PATTERN).map { it.toIntOrNull() ?: 0 }
val pydanticVersion = when {
versionList.size == 1 -> KotlinVersion(versionList[0], 0)
versionList.size == 2 -> KotlinVersion(versionList[0], versionList[1])
versionList.size >= 3 -> KotlinVersion(versionList[0], versionList[1], versionList[2])
else -> null
} ?: KotlinVersion(0, 0)
} ?: return null
pydanticVersionCache[version] = pydanticVersion
pydanticVersion
}
}

private fun getOrPutVersion(context: TypeEvalContext): KotlinVersion? {
internal fun getOrPutVersion(): KotlinVersion? {
if (version != null) return version
return getVersion(context).apply { version = this }
return getVersion().apply { version = this }
}

internal fun setVersion(version: String): KotlinVersion? {
return getOrPutVersionFromVersionCache(version).also { this.version = it }
}

private fun getOrAllowedConfigKwargs(context: TypeEvalContext): Set<String>? {
if (allowedConfigKwargs != null) return allowedConfigKwargs
return getAllowedConfigKwargs(context).apply { allowedConfigKwargs = this }
Expand All @@ -55,16 +54,19 @@ class PydanticCacheService(val project: Project) {
allowedConfigKwargs = null
}

internal fun isV2(typeEvalContext: TypeEvalContext) = this.getOrPutVersion(typeEvalContext).isV2
internal val isV2 get() = this.getOrPutVersion().isV2

companion object {
fun getVersion(project: Project, context: TypeEvalContext): KotlinVersion? {
return getInstance(project).getOrPutVersion(context)
fun getVersion(project: Project): KotlinVersion? {
return getInstance(project).getOrPutVersion()
}

fun setVersion(project: Project, version: String): KotlinVersion? {
return getInstance(project).setVersion(version)
}
fun getOrPutVersionFromVersionCache(project: Project, version: String): KotlinVersion? {
return getInstance(project).getOrPutVersionFromVersionCache(version)
}

fun getAllowedConfigKwargs(project: Project, context: TypeEvalContext): Set<String>? {
return getInstance(project).getOrAllowedConfigKwargs(context)
Expand Down
2 changes: 1 addition & 1 deletion src/com/koxudaxi/pydantic/PydanticCompletionContributor.kt
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ class PydanticCompletionContributor : CompletionContributor() {
genericTypeMap: Map<PyGenericType, PyType>?,
withEqual: Boolean
) {
val pydanticVersion = PydanticCacheService.getVersion(pyClass.project, typeEvalContext)
val pydanticVersion = PydanticCacheService.getVersion(pyClass.project)
getClassVariables(pyClass, typeEvalContext)
.filter { it.name != null }
.filterNot { isUntouchedClass(it.findAssignedValue(), config, typeEvalContext) }
Expand Down
2 changes: 1 addition & 1 deletion src/com/koxudaxi/pydantic/PydanticIgnoreInspection.kt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class PydanticIgnoreInspection : PyInspectionExtension() {
return function.containingClass?.let {
isPydanticModel(it,
true,
context) && function.isValidatorMethod
context) && function.isValidatorMethod(PydanticCacheService.getVersion(function.project))
} == true
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class PydanticInsertArgumentsQuickFix(private val onlyRequired: Boolean) : Local
}.nullize()?.toMap() ?: return null
val elementGenerator = PyElementGenerator.getInstance(project)
val ellipsis = elementGenerator.createEllipsis()
val pydanticVersion = PydanticCacheService.getVersion(project, context)
val pydanticVersion = PydanticCacheService.getVersion(project)
val fields = (listOf(pyClass) + getAncestorPydanticModels(pyClass, true, context)).flatMap {
it.classAttributes.filter { attribute -> unFilledArguments.contains(attribute.name) }
.mapNotNull { attribute -> attribute.name?.let { name -> name to attribute }}
Expand Down
8 changes: 4 additions & 4 deletions src/com/koxudaxi/pydantic/PydanticInspection.kt
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class PydanticInspection : PyInspection() {
super.visitPyFunction(node)

if (getPydanticModelByAttribute(node, true, myTypeEvalContext) == null) return
if (!node.isValidatorMethod) return
if (!node.isValidatorMethod(pydanticCacheService.getOrPutVersion())) return
val paramList = node.parameterList
val params = paramList.parameters
val firstParam = params.firstOrNull()
Expand Down Expand Up @@ -87,7 +87,7 @@ class PydanticInspection : PyInspection() {
override fun visitPyClass(node: PyClass) {
super.visitPyClass(node)

if(pydanticCacheService.isV2(myTypeEvalContext)) {
if(pydanticCacheService.isV2) {
inspectCustomRootFieldV2(node)
}
inspectConfig(node)
Expand Down Expand Up @@ -217,7 +217,7 @@ class PydanticInspection : PyInspection() {
}

private fun inspectConfig(pyClass: PyClass) {
val pydanticVersion = PydanticCacheService.getVersion(pyClass.project, myTypeEvalContext)
val pydanticVersion = PydanticCacheService.getVersion(pyClass.project)
if (pydanticVersion?.isAtLeast(1, 8) != true) return
if (!isPydanticModel(pyClass, false, myTypeEvalContext)) return
validateConfig(pyClass, myTypeEvalContext)?.forEach {
Expand All @@ -237,7 +237,7 @@ class PydanticInspection : PyInspection() {
val pyClass = pyClassType.pyClass
val attributeName = (node.leftHandSideExpression as? PyTargetExpressionImpl)?.name ?: return
val config = getConfig(pyClass, myTypeEvalContext, true)
val version = PydanticCacheService.getVersion(pyClass.project, myTypeEvalContext)
val version = PydanticCacheService.getVersion(pyClass.project)
if (config["allow_mutation"] == false || (version?.isAtLeast(1, 8) == true && config["frozen"] == true)) {
registerProblem(
node,
Expand Down
10 changes: 4 additions & 6 deletions src/com/koxudaxi/pydantic/PydanticPackageManagerListener.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,18 @@ import com.intellij.openapi.project.ProjectManager
import com.intellij.openapi.projectRoots.Sdk
import com.intellij.openapi.util.Disposer
import com.jetbrains.python.packaging.PyPackageManager
import com.jetbrains.python.packaging.PyPackageManagers
import com.jetbrains.python.sdk.PythonSdkUtil
import com.jetbrains.python.statistics.sdks

class PydanticPackageManagerListener : PyPackageManager.Listener {
private fun updateVersion(sdk: Sdk) {
val version = PyPackageManagers.getInstance()
.forSdk(sdk).packages?.find { it.name == "pydantic" }?.version
val version = sdk.pydanticVersion
ProjectManager.getInstance().openProjects
.filter { it.sdks.contains(sdk) }
.forEach {
when (version) {
is String -> PydanticCacheService.setVersion(it, version)
else -> PydanticCacheService.clear(it)
PydanticCacheService.clear(it)
if (version is String) {
PydanticCacheService.getOrPutVersionFromVersionCache(it, version)
}
}
}
Expand Down
9 changes: 5 additions & 4 deletions src/com/koxudaxi/pydantic/PydanticTypeProvider.kt
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ class PydanticTypeProvider : PyTypeProviderBase() {
getRefTypeFromFieldName(name, context, pyClass)
}

param.isSelf && func.isValidatorMethod -> {
param.isSelf && func.isValidatorMethod(PydanticCacheService.getVersion(func.project)
) -> {
val pyClass = func.containingClass ?: return null
if (!isPydanticModel(pyClass, false, context)) return null
context.getType(pyClass)
Expand All @@ -103,7 +104,7 @@ class PydanticTypeProvider : PyTypeProviderBase() {
private fun getRefTypeFromFieldName(name: String, context: TypeEvalContext, pyClass: PyClass): PyType? {
val ellipsis = PyElementGenerator.getInstance(pyClass.project).createEllipsis()

val pydanticVersion = PydanticCacheService.getVersion(pyClass.project, context)
val pydanticVersion = PydanticCacheService.getVersion(pyClass.project)
return getRefTypeFromFieldNameInPyClass(name, pyClass, context, ellipsis, pydanticVersion)
?: getAncestorPydanticModels(pyClass, false, context).firstNotNullOfOrNull { ancestor ->
getRefTypeFromFieldNameInPyClass(name, ancestor, context, ellipsis, pydanticVersion)
Expand Down Expand Up @@ -298,7 +299,7 @@ class PydanticTypeProvider : PyTypeProviderBase() {
): PydanticDynamicModelClassType? {
val project = pyFunction.project
val typed = getInstance(project).currentInitTyped
val pydanticVersion = PydanticCacheService.getVersion(pyFunction.project, context)
val pydanticVersion = PydanticCacheService.getVersion(pyFunction.project)
val collected = linkedMapOf<String, PydanticDynamicModel.Attribute>()
val newVersion = pydanticVersion == null || pydanticVersion.isAtLeast(1, 5)
val modelNameParameterName = if (newVersion) "__model_name" else "model_name"
Expand Down Expand Up @@ -494,7 +495,7 @@ class PydanticTypeProvider : PyTypeProviderBase() {
}
}
val genericTypeMap = getGenericTypeMap(pyClass, context, pyCallExpression)
val pydanticVersion = PydanticCacheService.getVersion(pyClass.project, context)
val pydanticVersion = PydanticCacheService.getVersion(pyClass.project)
val config = getConfig(pyClass, context, true)
for (currentType in StreamEx.of(clsType).append(pyClass.getAncestorTypes(context))) {
if (currentType !is PyClassType) continue
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import com.jetbrains.python.PythonLanguage
import com.jetbrains.python.codeInsight.PyCodeInsightSettings
import com.jetbrains.python.psi.PyFunction
import com.jetbrains.python.psi.impl.PyPsiUtils
import com.jetbrains.python.psi.types.TypeEvalContext
import java.util.regex.Pattern

class PydanticTypedValidatorMethodHandler : TypedHandlerDelegate() {
Expand Down Expand Up @@ -52,7 +53,7 @@ class PydanticTypedValidatorMethodHandler : TypedHandlerDelegate() {
val defNode = maybeDef.node
if (defNode != null && defNode.elementType === PyTokenTypes.DEF_KEYWORD) {
val pyFunction = token.parent as? PyFunction ?: return Result.CONTINUE
if (!pyFunction.isValidatorMethod) return Result.CONTINUE
if (!pyFunction.isValidatorMethod(PydanticCacheService.getVersion(project))) return Result.CONTINUE
val settings = CodeStyle.getLanguageSettings(file, PythonLanguage.getInstance())
val textToType = StringBuilder()
textToType.append("(")
Expand Down
45 changes: 45 additions & 0 deletions testData/inspectionv2/validatorSelf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from pydantic import BaseModel, field_validator, model_validator

def check(func):
def inner():
func()
return inner

class A(BaseModel):
a: str
b: str
c: str
d: str
e: str

@field_validator('a')
def validate_a(<weak_warning descr="Usually first parameter of such methods is named 'cls'">self</weak_warning>):
pass

@field_validator('b')
def validate_b(<weak_warning descr="Usually first parameter of such methods is named 'cls'">fles</weak_warning>):
pass

@field_validator('c')
def validate_b(*args):
pass

@field_validator('d')
def validate_c(**kwargs):
pass

@field_validator('e')
def validate_e<error descr="Method must have a first parameter, usually called 'cls'">()</error>:
pass

@model_validator()
def validate_model<error descr="Method must have a first parameter, usually called 'cls'">()</error>:
pass


def dummy(self):
pass

@check
def task(self):
pass
2 changes: 1 addition & 1 deletion testData/mock/pydanticv2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
ValidationInfo,
ValidatorFunctionWrapHandler,
)

from .field_validator import field_validator, model_validator
from . import dataclasses
from .analyzed_type import AnalyzedType
from .config import BaseConfig, ConfigDict, Extra
Expand Down
15 changes: 15 additions & 0 deletions testData/mock/pydanticv2/functional_validators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@

def field_validator(
__field: str,
*fields: str,
mode: FieldValidatorModes = 'after',
check_fields: bool | None = None,
) -> Callable[[Any], Any]:
pass


def model_validator(
*,
mode: Literal['wrap', 'before', 'after'],
) -> Any:
pass
3 changes: 3 additions & 0 deletions testSrc/com/koxudaxi/pydantic/PydanticInspectionV2Test.kt
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,7 @@ open class PydanticInspectionV2Test : PydanticInspectionBase(version = "v2") {
pydanticConfigService.mypyWarnUntypedFields = false
doTest()
}
fun testValidatorSelf() {
doTest()
}
}
Loading

0 comments on commit 6918ab2

Please sign in to comment.