Skip to content

Commit

Permalink
Allow adding custom compile-time checks in distage-framework (#2133)
Browse files Browse the repository at this point in the history
  • Loading branch information
neko-kai authored Jun 12, 2024
1 parent 8950996 commit 773d7bc
Show file tree
Hide file tree
Showing 7 changed files with 215 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import izumi.fundamentals.collections.nonempty.{NEList, NEMap, NESet}
import izumi.fundamentals.platform.strings.IzString.toRichIterable
import izumi.reflect.TagK

import java.util.concurrent.TimeUnit
import scala.annotation.{nowarn, tailrec}
import scala.concurrent.duration.FiniteDuration

Expand Down Expand Up @@ -280,6 +281,19 @@ object PlanVerifier {
final def verificationPassed: Boolean = issues.isEmpty
final def verificationFailed: Boolean = issues.nonEmpty

def combine(that: PlanVerifierResult): PlanVerifierResult = {
(this, that) match {
case (PlanVerifierResult.Incorrect(Some(i1), v1, t1), PlanVerifierResult.Incorrect(Some(i2), v2, t2)) =>
PlanVerifierResult.Incorrect(Some(i1 ++ i2), v1 ++ v2, t1 + t2)
case (fail: PlanVerifierResult.Incorrect, _: PlanVerifierResult.Correct) =>
fail
case (_: PlanVerifierResult.Correct, fail: PlanVerifierResult.Incorrect) =>
fail
case (PlanVerifierResult.Correct(v1, t1), PlanVerifierResult.Correct(v2, t2)) =>
PlanVerifierResult.Correct(v1 ++ v2, t1 + t2)
}
}

final def throwOnError(): Unit = this match {
case incorrect: PlanVerifierResult.Incorrect =>
throw new PlanVerificationException(
Expand All @@ -299,5 +313,7 @@ object PlanVerifier {
object PlanVerifierResult {
final case class Incorrect(issues: Some[NESet[PlanIssue]], visitedKeys: Set[DIKey], time: FiniteDuration) extends PlanVerifierResult
final case class Correct(visitedKeys: Set[DIKey], time: FiniteDuration) extends PlanVerifierResult { override def issues: None.type = None }

def empty: PlanVerifierResult = PlanVerifierResult.Correct(Set.empty, FiniteDuration(0, TimeUnit.SECONDS))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,24 @@ import izumi.distage.InjectorFactory
import izumi.distage.config.model.AppConfig
import izumi.distage.config.model.exceptions.DIConfigReadException
import izumi.distage.constructors.TraitConstructor
import izumi.distage.framework.PlanCheck.runtime.RoleSelection
import izumi.distage.framework.PlanCheck.RoleSelection
import izumi.distage.framework.model.PlanCheckInput
import izumi.distage.framework.services.ConfigLoader
import izumi.distage.model.definition.{Binding, BootstrapModule, Id, Module, ModuleBase, ModuleDef, impl}
import izumi.distage.model.plan.Roots
import izumi.distage.model.planning.AxisPoint
import izumi.distage.model.providers.Functoid
import izumi.distage.model.reflection.SafeType
import izumi.distage.modules.DefaultModule
import izumi.distage.planning.solver.PlanVerifier
import izumi.distage.planning.solver.PlanVerifier.PlanVerifierResult
import izumi.distage.plugins.load.LoadedPlugins
import izumi.distage.roles.launcher.RoleProvider
import izumi.distage.roles.model.meta.{RoleBinding, RolesInfo}
import izumi.fundamentals.collections.nonempty.NESet
import izumi.fundamentals.platform.cli.model.raw.RawAppArgs
import izumi.fundamentals.platform.functional.Identity
import izumi.fundamentals.platform.language.Quirks
import izumi.fundamentals.platform.language.Quirks.Discarder
import izumi.logstage.api.IzLogger
import izumi.reflect.TagK
Expand All @@ -43,6 +48,21 @@ trait CheckableApp {
selectedRoles: RoleSelection,
chosenConfigFile: Option[String],
): PlanCheckInput[AppEffectType]

/**
* Override this to execute additional arbitrary user-defined checks at compile-time (or runtime via `PlanCheck.runtime`)
*
* @throws Throwable You may throw a custom exception if your check error is not describable by [[izumi.distage.model.planning.PlanIssue]]
*/
def customCheck(
planVerifier: PlanVerifier,
excludedActivations: Set[NESet[AxisPoint]],
checkConfig: Boolean,
planCheckInput: PlanCheckInput[AppEffectType],
): PlanVerifierResult = {
Quirks.discard(planVerifier, excludedActivations, checkConfig, planCheckInput)
PlanVerifierResult.empty
}
}
object CheckableApp {
type Aux[F[_]] = CheckableApp { type AppEffectType[A] = F[A] }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ object PlanCheck {
planVerifier: PlanVerifier = PlanVerifier(),
logger: TrivialLogger = defaultLogger(),
): PlanCheckResult = {
val chosenRoles = parseRoles(cfg.roles)
val chosenRoles = RoleSelection.parseRoles(cfg.roles)
val chosenActivations = parseActivations(cfg.excludeActivations)
val chosenConfig = if (cfg.config == "*") None else Some(cfg.config)

Expand Down Expand Up @@ -124,13 +124,13 @@ object PlanCheck {
logger: TrivialLogger = defaultLogger(),
): PlanCheckResult = {

var effectiveRoleNames = "unknown, failed too early"
var effectiveRoots = "unknown, failed too early"
var effectiveConfig = "unknown, failed too early"
var effectiveBsPlugins = LoadedPlugins.empty
var effectiveAppPlugins = LoadedPlugins.empty
var effectiveModule = ModuleBase.empty
var effectivePlugins = LoadedPlugins.empty
var reportingEffectiveRoleNames = "unknown, failed too early"
var reportingEffectiveRoots = "unknown, failed too early"
var reportingEffectiveConfig = "unknown, failed too early"
var reportingEffectiveBsPlugins = LoadedPlugins.empty
var reportingEffectiveAppPlugins = LoadedPlugins.empty
var reportingEffectiveModule = ModuleBase.empty
var reportingEffectivePlugins = LoadedPlugins.empty

def renderPlugins(loadedPlugins: LoadedPlugins): String = {
val plugins = loadedPlugins.loaded
Expand Down Expand Up @@ -162,16 +162,16 @@ object PlanCheck {
val errorMsg = cause.fold("\n" + _.stacktraceString, _.issues.fromNESet.map(_.render + "\n").niceList())
val message = {
val configStr = if (checkConfig) {
s"\n config = ${chosenConfig.fold("*")(c => s"resource:$c")} (effective: $effectiveConfig)"
s"\n config = ${chosenConfig.fold("*")(c => s"resource:$c")} (effective: $reportingEffectiveConfig)"
} else {
""
}
val bindings = effectiveModule
val bsPlugins = effectiveBsPlugins.result
val appPlugins = effectiveAppPlugins.result
val bindings = reportingEffectiveModule
val bsPlugins = reportingEffectiveBsPlugins.result
val appPlugins = reportingEffectiveAppPlugins.result
// fixme missing DefaultModule bindings !!!
val bsPluginsStr = renderPlugins(effectiveBsPlugins)
val appPluginsStr = renderPlugins(effectiveAppPlugins)
val bsPluginsStr = renderPlugins(reportingEffectiveBsPlugins)
val appPluginsStr = renderPlugins(reportingEffectiveAppPlugins)
val printedBindings = if (printBindings) {
(if (bsPlugins.nonEmpty)
s"""
Expand Down Expand Up @@ -200,7 +200,7 @@ object PlanCheck {

s"""Found a problem with your DI wiring, when checking application=${app.getClass.getName.split('.').last.split('$').last}, with parameters:
|
| roles = $chosenRoles (effective roles: $effectiveRoleNames) (all effective roots: $effectiveRoots)
| roles = $chosenRoles (effective roles: $reportingEffectiveRoleNames) (all effective roots: $reportingEffectiveRoots)
| excludedActivations = ${NESet.from(excludedActivations).fold("ø")(_.map(_.mkString(" ")).mkString(" | "))}
| bootstrapPlugins = $bsPluginsStr
| plugins = $appPluginsStr
Expand All @@ -214,26 +214,29 @@ object PlanCheck {
|""".stripMargin
}

PlanCheckResult.Incorrect(effectivePlugins, visitedKeys, message, cause)
PlanCheckResult.Incorrect(reportingEffectivePlugins, visitedKeys, message, cause)
}

logger.log(s"Checking with roles=`$chosenRoles` excludedActivations=`$excludedActivations` chosenConfig=`$chosenConfig`")

try {
val input = app.preparePlanCheckInput(chosenRoles, chosenConfig)
val loadedPlugins = input.appPlugins ++ input.bsPlugins

effectiveRoleNames = input.roleNames.mkString(", ")
effectiveRoots = input.roots match {
reportingEffectiveRoleNames = input.roleNames.mkString(", ")
reportingEffectiveRoots = input.roots match {
case Roots.Of(roots) => roots.mkString(", ")
case Roots.Everything => "<Roots.Everything>"
}
effectiveModule = input.module
effectiveAppPlugins = input.appPlugins
effectiveBsPlugins = input.bsPlugins
val loadedPlugins = input.appPlugins ++ input.bsPlugins
effectivePlugins = loadedPlugins
reportingEffectiveModule = input.module
reportingEffectiveAppPlugins = input.appPlugins
reportingEffectiveBsPlugins = input.bsPlugins
reportingEffectivePlugins = loadedPlugins

val primaryCheckResult = checkAnyApp[F](planVerifier, excludedActivations, checkConfig, { reportingEffectiveConfig = _ }, input)
val additionalCheckResult = app.customCheck(planVerifier, excludedActivations, checkConfig, input)

checkAnyApp[F](planVerifier, excludedActivations, checkConfig, effectiveConfig = _)(input) match {
primaryCheckResult.combine(additionalCheckResult) match {
case incorrect: PlanVerifierResult.Incorrect => returnPlanCheckError(Right(incorrect))
case PlanVerifierResult.Correct(visitedKeys, _) => PlanCheckResult.Correct(loadedPlugins, visitedKeys)
}
Expand All @@ -244,12 +247,12 @@ object PlanCheck {
}
}

private[this] def checkAnyApp[F[_]](
def checkAnyApp[F[_]](
planVerifier: PlanVerifier,
excludedActivations: Set[NESet[AxisPoint]],
checkConfig: Boolean,
reportEffectiveConfig: String => Unit,
)(planCheckInput: PlanCheckInput[F]
planCheckInput: PlanCheckInput[F],
): PlanVerifierResult = {
val PlanCheckInput(effectType, module, roots, _, providedKeys, configLoader, _, _) = planCheckInput

Expand Down Expand Up @@ -296,20 +299,42 @@ object PlanCheck {
}
}

sealed trait RoleSelection {
override final def toString: String = this match {
case RoleSelection.Everything => "*"
case RoleSelection.OnlySelected(selection) => selection.mkString(" ")
case RoleSelection.AllExcluding(excluded) => excluded.map("-" + _).mkString(" ")
}
private def parseActivations(s: String): Set[NESet[AxisPoint]] = {
s.split("\\|").iterator.filter(_.nonEmpty).flatMap {
NESet `from` _.split(" ").iterator.filter(_.nonEmpty).map(AxisPoint.parseAxisPoint).toSet
}.toSet
}
object RoleSelection {
case object Everything extends RoleSelection
final case class OnlySelected(selection: Set[String]) extends RoleSelection
final case class AllExcluding(excluded: Set[String]) extends RoleSelection

private def defaultLogger(): TrivialLogger = {
TrivialLogger.make[this.type](DebugProperties.`izumi.debug.macro.distage.plancheck`.name)
}

private[this] def parseRoles(s: String): RoleSelection = {
@tailrec private def cutoffMacroTrace(t: Throwable): Unit = {
val trace = t.getStackTrace
val cutoffIdx = Some(trace.indexWhere(_.getClassName contains "scala.reflect.macros.runtime.JavaReflectionRuntimes$JavaReflectionResolvers")).filter(_ > 0)
t.setStackTrace(cutoffIdx.fold(trace)(trace.take))
val suppressed = t.getSuppressed
suppressed.foreach(cutSuppressed)
if (t.getCause ne null) cutoffMacroTrace(t.getCause)
}
// indirection for tailrec
private def cutSuppressed(t: Throwable): Unit = cutoffMacroTrace(t)

}

sealed trait RoleSelection {
override final def toString: String = this match {
case RoleSelection.Everything => "*"
case RoleSelection.OnlySelected(selection) => selection.mkString(" ")
case RoleSelection.AllExcluding(excluded) => excluded.map("-" + _).mkString(" ")
}
}
object RoleSelection {
case object Everything extends RoleSelection
final case class OnlySelected(selection: Set[String]) extends RoleSelection
final case class AllExcluding(excluded: Set[String]) extends RoleSelection

def parseRoles(s: String): RoleSelection = {
val tokens = s.split(" ").iterator.filter(_.nonEmpty).toList
tokens match {
case "*" :: Nil =>
Expand All @@ -323,7 +348,7 @@ object PlanCheck {
}
}

private[this] def throwInvalidRoleSelectionError(s: String): Nothing = {
private def throwInvalidRoleSelectionError(s: String): Nothing = {
throw new IllegalArgumentException(
s"""Invalid role selection syntax in `$s`.
|
Expand All @@ -335,28 +360,6 @@ object PlanCheck {
|""".stripMargin
)
}

private[this] def parseActivations(s: String): Set[NESet[AxisPoint]] = {
s.split("\\|").iterator.filter(_.nonEmpty).flatMap {
NESet `from` _.split(" ").iterator.filter(_.nonEmpty).map(AxisPoint.parseAxisPoint).toSet
}.toSet
}

private[this] def defaultLogger(): TrivialLogger = {
TrivialLogger.make[this.type](DebugProperties.`izumi.debug.macro.distage.plancheck`.name)
}

@tailrec private[this] def cutoffMacroTrace(t: Throwable): Unit = {
val trace = t.getStackTrace
val cutoffIdx = Some(trace.indexWhere(_.getClassName contains "scala.reflect.macros.runtime.JavaReflectionRuntimes$JavaReflectionResolvers")).filter(_ > 0)
t.setStackTrace(cutoffIdx.fold(trace)(trace.take))
val suppressed = t.getSuppressed
suppressed.foreach(cutSuppressed)
if (t.getCause ne null) cutoffMacroTrace(t.getCause)
}
// indirection for tailrec
private[this] def cutSuppressed(t: Throwable): Unit = cutoffMacroTrace(t)

}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package izumi.distage.roles.test

import cats.effect.IO
import distage.DIKey
import izumi.distage.framework.model.PlanCheckInput
import izumi.distage.model.planning.AxisPoint
import izumi.distage.planning.solver.PlanVerifier
import izumi.distage.planning.solver.PlanVerifier.PlanVerifierResult
import izumi.fundamentals.collections.nonempty.NESet
import izumi.fundamentals.platform.strings.IzString.toRichIterable

object CustomCheckEntrypoint extends TestEntrypointPatchedLeakBase {
override def customCheck(
planVerifier: PlanVerifier,
excludedActivations: Set[NESet[AxisPoint]],
checkConfig: Boolean,
planCheckInput: PlanCheckInput[IO],
): PlanVerifierResult = {
val reachable = planVerifier.traceReachables(planCheckInput.module, planCheckInput.roots, planCheckInput.providedKeys, excludedActivations)
val conflictKeys = planCheckInput.module.keys.filter {
// filter out any set elements (to remove weak set elements)
case _: DIKey.SetElementKey => false
// make sure that all keys we are checking contain 'Conflict' in their short type name
case other => other.tpe.tag.shortName.contains("Conflict")
}
val unused = conflictKeys -- reachable
if (unused.nonEmpty) {
throw new RuntimeException(s"Custom check failed, found unused bindings for following keys: ${unused.map(_.tpe.tag.repr).niceList()}")
} else {
PlanVerifierResult.empty
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ import izumi.fundamentals.platform.language.SourcePackageMaterializer.thisPkg
object TestEntrypoint extends TestEntrypointBase

// for `CompTimePlanCheckerTest`
object TestEntrypointPatchedLeak extends TestEntrypointBase {
object TestEntrypointPatchedLeak extends TestEntrypointPatchedLeakBase

trait TestEntrypointPatchedLeakBase extends TestEntrypointBase {
override protected def roleAppBootOverrides(argv: RoleAppMain.ArgV): Module = super.roleAppBootOverrides(argv) ++ new ModuleDef {
modify[Module].named("roleapp") {
_ ++ new ModuleDef {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import izumi.distage.framework.model.exceptions.PlanCheckException
import izumi.distage.framework.{PlanCheck, PlanCheckConfig}
import izumi.distage.model.planning.{AxisPoint, PlanIssue}
import izumi.distage.model.reflection.DIKey
import izumi.distage.roles.test.{TestEntrypoint, TestEntrypointPatchedLeak}
import izumi.distage.roles.test.{CustomCheckEntrypoint, TestEntrypoint, TestEntrypointPatchedLeak}
import izumi.fundamentals.collections.nonempty.NESet
import izumi.fundamentals.platform.language.literals.{LiteralBoolean, LiteralString}
import logstage.LogIO2
Expand Down Expand Up @@ -308,6 +308,30 @@ final class CompileTimePlanCheckerTest extends AnyWordSpec with GivenWhenThen {
assert(dep != null)
}

"Support custom checks" in {
val res = PlanCheck.runtime.checkApp(
CustomCheckEntrypoint,
PlanCheckConfig(
roles = "* -failingrole01 -failingrole02",
checkConfig = false,
excludeActivations = "mode:test",
),
)
assert(res.maybeErrorMessage.exists(_.contains("Custom check failed")))

val err = intercept[TestFailedException](assertCompiles("""
new PlanCheck.Main(
CustomCheckEntrypoint,
PlanCheckConfig(
roles = "* -failingrole01 -failingrole02",
checkConfig = false,
excludeActivations = "mode:test",
),
).assertAgainAtRuntime()
"""))
assert(err.getMessage.contains("Custom check failed"))
}

"progression test: role app fails check for excluded compound activations that are equivalent to just excluding `mode:test`" in {
val res = PlanCheck.runtime.checkApp(
TestEntrypointPatchedLeak,
Expand Down
Loading

0 comments on commit 773d7bc

Please sign in to comment.