Skip to content

Commit

Permalink
Inspect from_orm (#85)
Browse files Browse the repository at this point in the history
* inspect from_orm
  • Loading branch information
koxudaxi authored Oct 30, 2019
1 parent a1ba961 commit c3c3851
Show file tree
Hide file tree
Showing 7 changed files with 110 additions and 21 deletions.
50 changes: 38 additions & 12 deletions src/com/koxudaxi/pydantic/Pydantic.kt
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,16 @@ val VERSION_SPLIT_PATTERN: Pattern = Pattern.compile("[.a-zA-Z]")!!

val pydanticVersionCache: HashMap<String, KotlinVersion> = hashMapOf()

val DEFAULT_CONFIG = mapOf(
"allow_population_by_alias" to "False",
"allow_population_by_field_name" to "False"
val DEFAULT_CONFIG = mapOf<String, Any?>(
"allow_population_by_alias" to false,
"allow_population_by_field_name" to false,
"orm_mode" to false
)

val CONFIG_TYPES = mapOf(
"allow_population_by_alias" to Boolean,
"allow_population_by_field_name" to Boolean,
"orm_mode" to Boolean
)

internal fun getPyClassByPyCallExpression(pyCallExpression: PyCallExpression, context: TypeEvalContext): PyClass? {
Expand Down Expand Up @@ -205,32 +212,51 @@ internal fun isValidFieldName(name: String): Boolean {
return name.first() != '_'
}

internal fun getConfigValue(name: String, value: Any?, context: TypeEvalContext): Any? {
if (value is PyReferenceExpression) {
val resolveResults = getResolveElements(value, context)
val targetExpression = PyUtil.filterTopPriorityResults(resolveResults).firstOrNull() ?: return null
val assignedValue = (targetExpression as? PyTargetExpression)?.findAssignedValue() ?: return null
return getConfigValue(name, assignedValue, context)
}
when (CONFIG_TYPES[name]) {
Boolean ->
when (value) {
is PyBoolLiteralExpression -> return value.value
is Boolean -> return value
}

}
return null
}

internal fun getConfig(pyClass: PyClass, context: TypeEvalContext, setDefault: Boolean): HashMap<String, String?> {
val config = hashMapOf<String, String?>()
internal fun getConfig(pyClass: PyClass, context: TypeEvalContext, setDefault: Boolean): HashMap<String, Any?> {
val config = hashMapOf<String, Any?>()
pyClass.getAncestorClasses(context)
.reversed()
.filter { isPydanticModel(it) }
.map { getConfig(it, context, false) }
.forEach {
it.entries.forEach { entry ->
if (entry.value != null) {
config[entry.key] = entry.value
config[entry.key] = getConfigValue(entry.key, entry.value, context)
}
}
}
pyClass.nestedClasses.firstOrNull { isConfigClass(it) }?.let {
it.classAttributes.forEach { attribute ->
attribute.findAssignedValue()?.text?.let { value ->
attribute.name?.let { name -> config[name] = value }
attribute.findAssignedValue()?.let { value ->
attribute.name?.let { name ->
config[name] = getConfigValue(name, value, context)
}
}
}
}

if (setDefault) {
DEFAULT_CONFIG.forEach { (key, value) ->
if (!config.containsKey(key)) {
config[key] = value
config[key] = getConfigValue(key, value, context)
}
}
}
Expand All @@ -239,17 +265,17 @@ internal fun getConfig(pyClass: PyClass, context: TypeEvalContext, setDefault: B

internal fun getFieldName(field: PyTargetExpression,
context: TypeEvalContext,
config: HashMap<String, String?>,
config: HashMap<String, Any?>,
pydanticVersion: KotlinVersion?): String? {

return if (pydanticVersion?.major == 0) {
if (config["allow_population_by_alias"] == "True") {
if (config["allow_population_by_alias"] == true) {
field.name
} else {
getAliasedFieldName(field, context, pydanticVersion)
}
} else {
if (config["allow_population_by_field_name"] == "True") {
if (config["allow_population_by_field_name"] == true) {
field.name
} else {
getAliasedFieldName(field, context, pydanticVersion)
Expand Down
12 changes: 6 additions & 6 deletions src/com/koxudaxi/pydantic/PydanticCompletionContributor.kt
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class PydanticCompletionContributor : CompletionContributor() {

abstract val icon: Icon

abstract fun getLookupNameFromFieldName(field: PyTargetExpression, context: TypeEvalContext, pydanticVersion: KotlinVersion?, config: HashMap<String, String?>): String
abstract fun getLookupNameFromFieldName(field: PyTargetExpression, context: TypeEvalContext, pydanticVersion: KotlinVersion?, config: HashMap<String, Any?>): String

val typeProvider: PydanticTypeProvider = PydanticTypeProvider()

Expand All @@ -63,7 +63,7 @@ class PydanticCompletionContributor : CompletionContributor() {
pyTargetExpression: PyTargetExpression,
ellipsis: PyNoneLiteralExpression,
pydanticVersion: KotlinVersion?,
config: HashMap<String, String?>): String {
config: HashMap<String, Any?>): String {

val parameter = typeProvider.fieldToParameter(pyTargetExpression, ellipsis, typeEvalContext, pyClass, pydanticVersion, config)
val defaultValue = parameter?.defaultValue?.let {
Expand All @@ -81,7 +81,7 @@ class PydanticCompletionContributor : CompletionContributor() {
private fun addFieldElement(pyClass: PyClass, results: LinkedHashMap<String, LookupElement>,
typeEvalContext: TypeEvalContext,
ellipsis: PyNoneLiteralExpression,
config: HashMap<String, String?>,
config: HashMap<String, Any?>,
excludes: HashSet<String>?) {
val pydanticVersion = getPydanticVersion(pyClass.project, typeEvalContext)
getClassVariables(pyClass, typeEvalContext)
Expand All @@ -103,7 +103,7 @@ class PydanticCompletionContributor : CompletionContributor() {
protected fun addAllFieldElement(parameters: CompletionParameters, result: CompletionResultSet,
pyClass: PyClass, typeEvalContext: TypeEvalContext,
ellipsis: PyNoneLiteralExpression,
config: HashMap<String, String?>,
config: HashMap<String, Any?>,
excludes: HashSet<String>? = null) {

val newElements: LinkedHashMap<String, LookupElement> = LinkedHashMap()
Expand Down Expand Up @@ -164,7 +164,7 @@ class PydanticCompletionContributor : CompletionContributor() {
}

private object KeywordArgumentCompletionProvider : PydanticCompletionProvider() {
override fun getLookupNameFromFieldName(field: PyTargetExpression, context: TypeEvalContext, pydanticVersion: KotlinVersion?, config: HashMap<String, String?>): String {
override fun getLookupNameFromFieldName(field: PyTargetExpression, context: TypeEvalContext, pydanticVersion: KotlinVersion?, config: HashMap<String, Any?>): String {
return "${getFieldName(field, context, config, pydanticVersion)}="
}

Expand All @@ -189,7 +189,7 @@ class PydanticCompletionContributor : CompletionContributor() {
}

private object FieldCompletionProvider : PydanticCompletionProvider() {
override fun getLookupNameFromFieldName(field: PyTargetExpression, context: TypeEvalContext, pydanticVersion: KotlinVersion?, config: HashMap<String, String?>): String {
override fun getLookupNameFromFieldName(field: PyTargetExpression, context: TypeEvalContext, pydanticVersion: KotlinVersion?, config: HashMap<String, Any?>): String {
return field.name!!
}

Expand Down
26 changes: 24 additions & 2 deletions src/com/koxudaxi/pydantic/PydanticInspection.kt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import com.jetbrains.python.inspections.quickfix.RenameParameterQuickFix
import com.jetbrains.python.psi.*
import com.jetbrains.python.psi.impl.PyReferenceExpressionImpl
import com.jetbrains.python.psi.impl.PyStarArgumentImpl
import com.jetbrains.python.psi.resolve.PyResolveContext
import com.jetbrains.python.psi.types.PyClassType

class PydanticInspection : PyInspection() {

Expand Down Expand Up @@ -47,11 +49,11 @@ class PydanticInspection : PyInspection() {
if (node == null) return

inspectPydanticModelCallableExpression(node)

inspectFromOrm(node)

}

private fun inspectPydanticModelCallableExpression(pyCallExpression :PyCallExpression) {
private fun inspectPydanticModelCallableExpression(pyCallExpression: PyCallExpression) {
val pyClass = getPyClassByPyCallExpression(pyCallExpression, myTypeEvalContext) ?: return
if (!isPydanticModel(pyClass, myTypeEvalContext)) return
if ((pyCallExpression.callee as? PyReferenceExpressionImpl)?.isQualified == true) return
Expand All @@ -62,5 +64,25 @@ class PydanticInspection : PyInspection() {
"class '${pyClass.name}' accepts only keyword arguments")
}
}

private fun inspectFromOrm(pyCallExpression: PyCallExpression) {
if (!pyCallExpression.isCalleeText("from_orm")) return
val resolveContext = PyResolveContext.noImplicits().withTypeEvalContext(myTypeEvalContext)
val pyCallable = pyCallExpression.multiResolveCalleeFunction(resolveContext).firstOrNull() ?: return
if (pyCallable.asMethod()?.qualifiedName != "pydantic.main.BaseModel.from_orm") return
val typedElement = pyCallExpression.node?.firstChildNode?.firstChildNode?.psi as? PyTypedElement ?: return
val pyClass = when (val type = myTypeEvalContext.getType(typedElement)) {
is PyClass -> type
is PyClassType -> getPyClassTypeByPyTypes(type).firstOrNull { isPydanticModel(it.pyClass) }?.pyClass
else -> null
} ?: return
if (!isPydanticModel(pyClass)) return
val config = getConfig(pyClass, myTypeEvalContext, true)
if (config["orm_mode"] != true) {
registerProblem(pyCallExpression,
"You must have the config attribute orm_mode=True to use from_orm",
ProblemHighlightType.GENERIC_ERROR)
}
}
}
}
2 changes: 1 addition & 1 deletion src/com/koxudaxi/pydantic/PydanticTypeProvider.kt
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ class PydanticTypeProvider : PyTypeProviderBase() {
context: TypeEvalContext,
pyClass: PyClass,
pydanticVersion: KotlinVersion?,
config: HashMap<String, String?>): PyCallableParameter? {
config: HashMap<String, Any?>): PyCallableParameter? {
if (field.name == null || ! isValidFieldName(field.name!!)) return null
if (!hasAnnotationValue(field) && !field.hasAssignedValue()) return null // skip fields that are invalid syntax

Expand Down
34 changes: 34 additions & 0 deletions testData/inspection/ormMode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from pydantic import BaseModel


class A(BaseModel):
pass
<error descr="You must have the config attribute orm_mode=True to use from_orm">A.from_orm('')</error>

class B(BaseModel):
class Config:
orm_mode=False
<error descr="You must have the config attribute orm_mode=True to use from_orm">B.from_orm('')</error>


class C(BaseModel):
class Config:
orm_mode=True
C.from_orm('')

class C(BaseModel):
class Config:
orm_mode=False
<error descr="You must have the config attribute orm_mode=True to use from_orm">C.from_orm('')</error>

orm_mode = True
class D(BaseModel):
class Config:
orm_mode=orm_mode
D.from_orm('')


class E(D):
class Config:
orm_mode=False
<error descr="You must have the config attribute orm_mode=True to use from_orm">E.from_orm('')</error>
3 changes: 3 additions & 0 deletions testData/mock/pydanticv1/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ class Config:

___slots__ = ()

@classmethod
def from_orm(cls, obj):
pass

class Extra(str):
allow = 'allow'
Expand Down
4 changes: 4 additions & 0 deletions testSrc/com/koxudaxi/pydantic/PydanticInspectionTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,8 @@ open class PydanticInspectionTest : PydanticInspectionBase() {
fun testRootValidatorSelf() {
doTest()
}

fun testOrmMode() {
doTest()
}
}

0 comments on commit c3c3851

Please sign in to comment.