Skip to content

Commit

Permalink
Make assisted injection detection more robust in code gen (#1979)
Browse files Browse the repository at this point in the history
It was previously a bit hard-coded, this makes it a bit more
runtime-dependent and flexible
  • Loading branch information
ZacSweers authored Mar 1, 2025
1 parent 90cc127 commit 9c4dada
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ internal object CircuitNames {
private const val RUNTIME_PACKAGE = "dev.zacsweers.metro"
val INJECT = ClassName(RUNTIME_PACKAGE, "Inject")
val ASSISTED = ClassName(RUNTIME_PACKAGE, "Assisted")
val ASSISTED_FACTORY = ClassName(RUNTIME_PACKAGE, "AssistedFactory")
val PROVIDER = ClassName(RUNTIME_PACKAGE, "Provider")
internal val CONTRIBUTES_INTO_SET = ClassName(RUNTIME_PACKAGE, "ContributesIntoSet")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ private class CircuitSymbolProcessor(

// If we annotated a class, check that the class isn't using assisted inject. If so, error and
// return
if (instantiationType == InstantiationType.CLASS && codegenMode != KOTLIN_INJECT_ANVIL) {
(annotatedElement as KSClassDeclaration).checkForAssistedInjection {
if (instantiationType == InstantiationType.CLASS) {
(annotatedElement as KSClassDeclaration).checkForAssistedInjection(codegenMode) {
return
}
}
Expand Down Expand Up @@ -226,12 +226,18 @@ private class CircuitSymbolProcessor(
}
}

private inline fun KSClassDeclaration.checkForAssistedInjection(exit: () -> Nothing) {
private inline fun KSClassDeclaration.checkForAssistedInjection(
codegenMode: CodegenMode,
exit: () -> Nothing,
) {
val assistedInjectClassName = codegenMode.runtime.assistedInject ?: return
// Check for an AssistedInject constructor
if (findConstructorAnnotatedWith(CircuitNames.ASSISTED_INJECT) != null) {
if (findConstructorAnnotatedWith(assistedInjectClassName) != null) {
val assistedFactory =
declarations.filterIsInstance<KSClassDeclaration>().find {
it.isAnnotationPresentWithLeniency(CircuitNames.ASSISTED_FACTORY)
declarations.filterIsInstance<KSClassDeclaration>().find { nestedClass ->
codegenMode.runtime.assistedFactory?.let {
nestedClass.isAnnotationPresentWithLeniency(it)
} == true
}
val suffix =
if (assistedFactory != null) " (${assistedFactory.qualifiedName?.asString()})" else ""
Expand Down Expand Up @@ -409,12 +415,13 @@ private class CircuitSymbolProcessor(
?.filter { it.isAnnotationPresentWithLeniency(codegenMode.runtime.assisted) }
.orEmpty()
}

val isAssisted =
if (codegenMode == KOTLIN_INJECT_ANVIL) {
assistedKSParams.isNotEmpty()
} else {
declaration.isAnnotationPresentWithLeniency(CircuitNames.ASSISTED_FACTORY)
}
assistedKSParams.isNotEmpty() ||
codegenMode.runtime.assistedFactory?.let {
declaration.isAnnotationPresentWithLeniency(it)
} == true

val creatorOrConstructor: KSFunctionDeclaration?
val targetClass: KSClassDeclaration
if (isAssisted) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,8 @@ internal enum class CodegenMode {
sealed interface InjectionRuntime {
val inject: ClassName
val assisted: ClassName
val assistedInject: ClassName?
val assistedFactory: ClassName?

fun asProvider(providedType: TypeName): TypeName

Expand All @@ -266,6 +268,8 @@ internal enum class CodegenMode {
data object Javax : InjectionRuntime {
override val inject: ClassName = CircuitNames.INJECT
override val assisted: ClassName = CircuitNames.ASSISTED
override val assistedInject: ClassName = CircuitNames.ASSISTED_INJECT
override val assistedFactory: ClassName = CircuitNames.ASSISTED_FACTORY

override fun asProvider(providedType: TypeName): TypeName {
return CircuitNames.PROVIDER.parameterizedBy(providedType)
Expand All @@ -279,6 +283,8 @@ internal enum class CodegenMode {
data object KotlinInject : InjectionRuntime {
override val inject: ClassName = CircuitNames.KotlinInject.INJECT
override val assisted: ClassName = CircuitNames.KotlinInject.ASSISTED
override val assistedInject: ClassName? = null
override val assistedFactory: ClassName? = null // It has no annotation

override fun asProvider(providedType: TypeName): TypeName {
return LambdaTypeName.get(returnType = providedType)
Expand All @@ -292,6 +298,8 @@ internal enum class CodegenMode {
data object Metro : InjectionRuntime {
override val inject: ClassName = CircuitNames.Metro.INJECT
override val assisted: ClassName = CircuitNames.Metro.ASSISTED
override val assistedInject: ClassName? = null
override val assistedFactory: ClassName = CircuitNames.Metro.ASSISTED_FACTORY

override fun asProvider(providedType: TypeName): TypeName {
return CircuitNames.Metro.PROVIDER.parameterizedBy(providedType)
Expand Down

0 comments on commit 9c4dada

Please sign in to comment.