Skip to content

Commit

Permalink
Merge pull request #656 from koxudaxi/Reduce_unnecessary_resolve_in_t…
Browse files Browse the repository at this point in the history
…ype_providers

Reduce unnecessary resolve in type providers
  • Loading branch information
koxudaxi authored Feb 27, 2023
2 parents 1f34d0d + 96b556f commit 20d346a
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 110 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Unreleased
- Fix wrong inspections when a model has a __call__ method [[#655](https://github.com/koxudaxi/pydantic-pycharm-plugin/pull/655)]
- Reduce unnecessary resolve in type providers [[#656](https://github.com/koxudaxi/pydantic-pycharm-plugin/pull/656)]

## 0.3.17 - 2022-12-16
- Support Union operator [[#602](https://github.com/koxudaxi/pydantic-pycharm-plugin/pull/602)]
Expand Down
59 changes: 24 additions & 35 deletions src/com/koxudaxi/pydantic/PydanticDataclassTypeProvider.kt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.koxudaxi.pydantic

import com.intellij.openapi.util.Ref
import com.intellij.psi.PsiElement
import com.jetbrains.python.codeInsight.stdlib.PyDataclassTypeProvider
import com.jetbrains.python.psi.*
Expand All @@ -20,16 +21,30 @@ import com.jetbrains.python.psi.types.*
class PydanticDataclassTypeProvider : PyTypeProviderBase() {
private val pyDataclassTypeProvider = PyDataclassTypeProvider()
private val pydanticTypeProvider = PydanticTypeProvider()
override fun getReferenceExpressionType(
referenceExpression: PyReferenceExpression,

override fun getReferenceType(
referenceTarget: PsiElement,
context: TypeEvalContext,
): PyType? {
return getPydanticDataclass(
referenceExpression,
TypeEvalContext.codeInsightFallback(referenceExpression.project)
)
}
anchor: PsiElement?
): Ref<PyType>? {
return when {
referenceTarget is PyClass && referenceTarget.isPydanticDataclass ->
getPydanticDataclassType(referenceTarget, context, anchor as? PyCallExpression, true)

referenceTarget is PyTargetExpression -> (referenceTarget as? PyTypedElement)
?.getType(context)?.pyClassTypes
?.filter { pyClassType -> pyClassType.pyClass.isPydanticDataclass }
?.firstNotNullOfOrNull { pyClassType ->
getPydanticDataclassType(
pyClassType.pyClass,
context,
anchor as? PyCallExpression,
pyClassType.isDefinition
)
}
else ->null
}?.let { Ref.create(it) }
}

internal fun getDataclassCallableType(
referenceTarget: PsiElement,
Expand All @@ -46,10 +61,9 @@ class PydanticDataclassTypeProvider : PyTypeProviderBase() {
private fun getPydanticDataclassType(
referenceTarget: PsiElement,
context: TypeEvalContext,
pyReferenceExpression: PyReferenceExpression,
callSite: PyCallExpression?,
definition: Boolean,
): PyType? {
val callSite = PyCallExpressionNavigator.getPyCallExpressionByCallee(pyReferenceExpression)
val dataclassCallableType = getDataclassCallableType(referenceTarget, context, callSite) ?: return null

val dataclassType = (dataclassCallableType).getReturnType(context) as? PyClassType ?: return null
Expand All @@ -73,29 +87,4 @@ class PydanticDataclassTypeProvider : PyTypeProviderBase() {
else -> injectedDataclassType
}
}

private fun getPydanticDataclass(referenceExpression: PyReferenceExpression, context: TypeEvalContext): PyType? {
return getResolvedPsiElements(referenceExpression, context)
.asSequence()
.mapNotNull {
when {
it is PyClass && it.isPydanticDataclass ->
getPydanticDataclassType(it, context, referenceExpression, true)

it is PyTargetExpression -> (it as? PyTypedElement)
?.getType(context)?.pyClassTypes
?.filter { pyClassType -> pyClassType.pyClass.isPydanticDataclass }
?.firstNotNullOfOrNull { pyClassType ->
getPydanticDataclassType(
pyClassType.pyClass,
context,
referenceExpression,
pyClassType.isDefinition
)
}

else -> null
}
}.firstOrNull()
}
}
107 changes: 32 additions & 75 deletions src/com/koxudaxi/pydantic/PydanticTypeProvider.kt
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,6 @@ import one.util.streamex.StreamEx

class PydanticTypeProvider : PyTypeProviderBase() {
private val pyTypingTypeProvider = PyTypingTypeProvider()
override fun getReferenceExpressionType(
referenceExpression: PyReferenceExpression,
context: TypeEvalContext,
): PyType? {
return getPydanticTypeForCallee(referenceExpression, context)
}

override fun getCallType(
pyFunction: PyFunction,
Expand All @@ -46,11 +40,31 @@ class PydanticTypeProvider : PyTypeProviderBase() {
context: TypeEvalContext,
anchor: PsiElement?,
): Ref<PyType>? {
if (referenceTarget !is PyTargetExpression) return null
val pyClass = getPyClassByAttribute(referenceTarget.parent) ?: return null
if (!isPydanticModel(pyClass, false, context)) return null
val name = referenceTarget.name ?: return null
return getRefTypeFromFieldName(name, context, pyClass)
return when {
referenceTarget is PyClass && anchor is PyCallExpression -> getPydanticTypeForClass(
referenceTarget,
context,
getInstance(anchor.project).currentInitTyped,
anchor
)
referenceTarget is PyCallExpression -> {
getPydanticDynamicModelTypeForTargetExpression(referenceTarget, context)?.pyCallableType
}
referenceTarget is PyTargetExpression -> {
val name = referenceTarget.name
if (name is String) {
val pyClass = getPyClassByAttribute(referenceTarget.parent)
?.takeIf { isPydanticModel(it, false, context) }
if (pyClass is PyClass) {
return Ref.create(getRefTypeFromFieldName(name, context, pyClass))
}
}

getPydanticDynamicModelTypeForTargetExpression(referenceTarget, context)?.let { return Ref.create(it)}
}

else -> null
}?.let { Ref.create(it) }
}

override fun getParameterType(param: PyNamedParameter, func: PyFunction, context: TypeEvalContext): Ref<PyType>? {
Expand All @@ -65,11 +79,11 @@ class PydanticTypeProvider : PyTypeProviderBase() {
param.isSelf && func.isValidatorMethod -> {
val pyClass = func.containingClass ?: return null
if (!isPydanticModel(pyClass, false, context)) return null
Ref.create(context.getType(pyClass))
context.getType(pyClass)
}

else -> null
}
}?.let { Ref.create(it) }
}

private fun getRefTypeFromFieldNameInPyClass(
Expand All @@ -78,12 +92,12 @@ class PydanticTypeProvider : PyTypeProviderBase() {
context: TypeEvalContext,
ellipsis: PyNoneLiteralExpression,
pydanticVersion: KotlinVersion?,
): Ref<PyType>? {
): PyType? {
return pyClass.findClassAttribute(name, false, context)
?.let { return getRefTypeFromField(it, ellipsis, context, pyClass, pydanticVersion) }
}

private fun getRefTypeFromFieldName(name: String, context: TypeEvalContext, pyClass: PyClass): Ref<PyType>? {
private fun getRefTypeFromFieldName(name: String, context: TypeEvalContext, pyClass: PyClass): PyType? {
val ellipsis = PyElementGenerator.getInstance(pyClass.project).createEllipsis()

val pydanticVersion = PydanticCacheService.getVersion(pyClass.project, context)
Expand All @@ -97,7 +111,7 @@ class PydanticTypeProvider : PyTypeProviderBase() {
pyTargetExpression: PyTargetExpression, ellipsis: PyNoneLiteralExpression,
context: TypeEvalContext, pyClass: PyClass,
pydanticVersion: KotlinVersion?,
): Ref<PyType>? {
): PyType? {
return fieldToParameter(
pyTargetExpression,
ellipsis,
Expand All @@ -106,8 +120,7 @@ class PydanticTypeProvider : PyTypeProviderBase() {
pydanticVersion,
getConfig(pyClass, context, true),
getGenericTypeMap(pyClass, context)
)
?.let { parameter -> Ref.create(parameter.getType(context)) }
)?.getType(context)
}


Expand Down Expand Up @@ -208,62 +221,6 @@ class PydanticTypeProvider : PyTypeProviderBase() {
)?.let { Ref.create(it) }
}

private fun getPydanticTypeForCallee(
referenceExpression: PyReferenceExpression,
context: TypeEvalContext,
): PyType? {
val pyCallExpression = PyCallExpressionNavigator.getPyCallExpressionByCallee(referenceExpression) ?: return null

return getResolvedPsiElements(referenceExpression, context)
.asSequence()
.mapNotNull {
when {
it is PyClass -> getPydanticTypeForClass(it, context, true, pyCallExpression)
it is PyParameter && it.isSelf ->
PsiTreeUtil.getParentOfType(it, PyFunction::class.java)
?.takeIf { pyFunction -> pyFunction.modifier == PyFunction.Modifier.CLASSMETHOD }
?.containingClass?.let { pyClass ->
getPydanticTypeForClass(
pyClass,
context,
true,
pyCallExpression
)
}

it is PyNamedParameter -> it.getArgumentType(context)?.pyClassTypes?.filter { pyClassType ->
pyClassType.isDefinition
}?.map { filteredPyClassType ->
getPydanticTypeForClass(
filteredPyClassType.pyClass,
context,
true,
pyCallExpression
)
}?.firstOrNull()

it is PyTargetExpression -> (it as? PyTypedElement)
?.let { pyTypedElement ->
context.getType(pyTypedElement)?.pyClassTypes
?.filter { pyClassType -> pyClassType.isDefinition }
?.filterNot { pyClassType -> pyClassType is PydanticDynamicModelClassType }
?.map { filteredPyClassType ->
getPydanticTypeForClass(
filteredPyClassType.pyClass,
context,
true,
pyCallExpression
)
}?.firstOrNull()
} ?: getPydanticDynamicModelTypeForTargetExpression(it, context)?.pyCallableType

else -> null
}
}
.firstOrNull()
}


private fun createConListPyType(pyCallSiteExpression: PyCallSiteExpression, context: TypeEvalContext): PyType? {
val pyCallExpression = pyCallSiteExpression as? PyCallExpression ?: return null
val argumentList = pyCallExpression.argumentList ?: return null
Expand Down Expand Up @@ -510,7 +467,7 @@ class PydanticTypeProvider : PyTypeProviderBase() {
if (isSubClassOfBaseSetting(pyClass, context)) {
getBaseSetting(pyClass, context)?.let { baseSetting ->
getBaseSettingInitParameters(baseSetting, context, typed)
?.mapNotNull { parameter -> parameter.name?.let { name -> name to parameter} }
?.mapNotNull { parameter -> parameter.name?.let { name -> name to parameter } }
?.let { collected.putAll(it) }
}
}
Expand Down

0 comments on commit 20d346a

Please sign in to comment.