Skip to content

Commit

Permalink
Fix Activation-specific configs being lost in config writer (#2087)
Browse files Browse the repository at this point in the history
* Bug: Activation-specific configs lost in config writer

* fix, but some cleanups might be required. Reboot removed

* wip

* wip

* wip

* wip

---------

Co-authored-by: Pavel Shirshov <pshirshov@eml.cc>
  • Loading branch information
neko-kai and pshirshov authored Mar 18, 2024
1 parent 80ed8ba commit 4595e31
Show file tree
Hide file tree
Showing 13 changed files with 499 additions and 353 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import izumi.distage.model.reflection.{DIKey, MirrorProvider}
import izumi.distage.planning.*
import izumi.distage.planning.sequential.{ForwardingRefResolverDefaultImpl, FwdrefLoopBreaker, SanityCheckerDefaultImpl}
import izumi.distage.planning.solver.SemigraphSolver.SemigraphSolverImpl
import izumi.distage.planning.solver.{GraphPreparations, PlanSolver, SemigraphSolver}
import izumi.distage.planning.solver.{GraphQueries, PlanSolver, SemigraphSolver}
import izumi.distage.provisioning.*
import izumi.distage.provisioning.strategies.*
import izumi.fundamentals.platform.functional.Identity
Expand Down Expand Up @@ -71,7 +71,7 @@ object BootstrapLocator {
val sanityChecker = new SanityCheckerDefaultImpl()
val resolver = new PlanSolver.Impl(
new SemigraphSolverImpl[DIKey, Int, InstantiationOp](),
new GraphPreparations(new BindingTranslator.Impl()),
new GraphQueries(new BindingTranslator.Impl()),
)

new PlannerDefaultImpl(
Expand Down Expand Up @@ -111,7 +111,7 @@ object BootstrapLocator {
make[MirrorProvider].fromValue(mirrorProvider)

make[PlanSolver].from[PlanSolver.Impl]
make[GraphPreparations]
make[GraphQueries]

make[SemigraphSolver[DIKey, Int, InstantiationOp]].from[SemigraphSolverImpl[DIKey, Int, InstantiationOp]]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,11 @@ object SubcontextHandler {

}
}

class TracingHandler() extends SubcontextHandler[Nothing] {
override def handle(binding: Binding, c: ImplDef.ContextImpl): Either[Nothing, SingletonWiring] = {
Right(SingletonWiring.PrepareSubcontext(c.extractingFunction, Plan.empty, c.implType, c.externalKeys, Set.empty))
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
package izumi.distage.planning.solver

import izumi.distage.model.definition.ModuleBase
import izumi.distage.model.definition.conflicts.{Annotated, Node}
import izumi.distage.model.plan.ExecutableOp.{CreateSet, InstantiationOp}
import izumi.distage.model.plan.Roots
import izumi.distage.model.planning.PlanIssue.*
import izumi.distage.model.planning.{ActivationChoices, AxisPoint, PlanIssue}
import izumi.distage.model.reflection.{DIKey, SafeType}
import izumi.distage.planning.SubcontextHandler
import izumi.distage.planning.solver.GenericSemigraphTraverse.{TraversalFailure, TraversalResult}
import izumi.distage.planning.solver.SemigraphSolver.SemiEdgeSeq
import izumi.distage.provisioning.strategies.ImportStrategyDefaultImpl
import izumi.functional.IzEither.*
import izumi.fundamentals.collections.IzCollections.*
import izumi.fundamentals.collections.nonempty.{NEList, NESet}
import izumi.fundamentals.collections.{ImmutableMultiMap, MutableMultiMap}
import izumi.reflect.TagK

import java.util.concurrent.TimeUnit
import scala.annotation.nowarn
import scala.collection.mutable
import scala.concurrent.duration.FiniteDuration

object GenericSemigraphTraverse {
case class TraversalResult(visitedKeys: Set[DIKey], time: FiniteDuration, maybeIssues: Option[NESet[PlanIssue]])
case class TraversalFailure[Err](time: FiniteDuration, issues: NEList[Err])
}

abstract class GenericSemigraphTraverse[Err](
queries: GraphQueries,
subcontextHandler: SubcontextHandler[Err],
) {

def traverse[F[_]: TagK](
bindings: ModuleBase,
roots: Roots,
providedKeys: DIKey => Boolean,
excludedActivations: Set[NESet[AxisPoint]],
): Either[TraversalFailure[Err], TraversalResult] = {
val before = System.currentTimeMillis()
(for {
ops <- queries.computeOperationsUnsafe(subcontextHandler, bindings).map(_.toSeq)
} yield {
val allAxis: Map[String, Set[String]] = ops.flatMap(_._1.axis).groupBy(_.axis).map {
case (axis, points) =>
(axis, points.map(_.value).toSet)
}
val (mutators, defns) = ops.partition(_._3.isMutator)
val justOps = defns.map { case (k, op, _) => k -> op }

val setOps = queries
.computeSetsUnsafe(justOps)
.map {
case (k, (s, _)) =>
(Annotated(k, None, Set.empty), Node(s.members, s))

}.toMultimapView
.map {
case (k, v) =>
val members = v.flatMap(_.deps).toSet
(k, Node(members, v.head.meta.copy(members = members): InstantiationOp))
}
.toSeq

val opsMatrix: Seq[(Annotated[DIKey], Node[DIKey, InstantiationOp])] = queries.toDeps(justOps)

val matrix: SemiEdgeSeq[Annotated[DIKey], DIKey, InstantiationOp] = SemiEdgeSeq(opsMatrix ++ setOps)

val matrixToTrace = defns.map { case (k, op, _) => (k.key, (op, k.axis)) }.toMultimap
val justMutators = mutators.map { case (k, op, _) => (k.key, (op, k.axis)) }.toMultimap

val rootKeys: Set[DIKey] = queries.getRoots(roots, justOps)
val execOpIndex: MutableMultiMap[DIKey, InstantiationOp] = queries.executableOpIndex(matrix)

val mutVisited = mutable.HashSet.empty[(DIKey, Set[AxisPoint])]
val effectType = SafeType.getK[F]

val issues =
trace(allAxis, mutVisited, matrixToTrace, execOpIndex, justMutators, providedKeys, excludedActivations, rootKeys, effectType, bindings)

val visitedKeys: Set[DIKey] = mutVisited.map(_._1).toSet
val after = System.currentTimeMillis()
val time: FiniteDuration = FiniteDuration(after - before, TimeUnit.MILLISECONDS)

val maybeIssues: Option[NESet[PlanIssue]] = NESet.from(issues)

TraversalResult(visitedKeys, time, maybeIssues)
}).left.map {
errs => TraversalFailure(FiniteDuration(System.currentTimeMillis() - before, TimeUnit.MILLISECONDS), errs)
}
}

@nowarn("msg=Unused import")
protected[this] def trace(
allAxis: Map[String, Set[String]],
allVisited: mutable.HashSet[(DIKey, Set[AxisPoint])],
matrix: ImmutableMultiMap[DIKey, (InstantiationOp, Set[AxisPoint])],
execOpIndex: MutableMultiMap[DIKey, InstantiationOp],
justMutators: ImmutableMultiMap[DIKey, (InstantiationOp, Set[AxisPoint])],
providedKeys: DIKey => Boolean,
excludedActivations: Set[NESet[AxisPoint]],
rootKeys: Set[DIKey],
effectType: SafeType,
bindings: ModuleBase,
): Set[PlanIssue] = {
import scala.collection.compat.*

@inline def go(visited: Set[DIKey], current: Set[(DIKey, DIKey)], currentActivation: Set[AxisPoint]): RecursionResult = RecursionResult(current.iterator.map {
case (key, dependee) =>
if (visited.contains(key) || allVisited.contains((key, currentActivation))) {
Right(Iterator.empty)
} else {
@inline def reportMissing[A](key: DIKey, dependee: DIKey): Left[NEList[MissingImport], Nothing] = {
val origins = queries.allImportingBindings(matrix, currentActivation)(key, dependee)
val similarBindings = ImportStrategyDefaultImpl.findSimilarImports(bindings, key)
Left(NEList(MissingImport(key, dependee, origins, similarBindings.similarSame, similarBindings.similarSub)))
}

@inline def reportMissingIfNotProvided[A](key: DIKey, dependee: DIKey)(orElse: => Either[NEList[PlanIssue], A]): Either[NEList[PlanIssue], A] = {
if (providedKeys(key)) orElse else reportMissing(key, dependee)
}

matrix.get(key) match {
case None =>
reportMissingIfNotProvided(key, dependee)(Right(Iterator.empty))

case Some(allOps) =>
val ops = allOps.filterNot(o => queries.isIgnoredActivation(excludedActivations, o._2))
val ac = ActivationChoices(currentActivation)

val withoutCurrentActivations = {
val withoutImpossibleActivationsIter = ops.iterator.filter(ac `allValid` _._2)
withoutImpossibleActivationsIter.map {
case (op, activations) =>
(op, activations diff currentActivation, activations)
}.toSet
}

for {
// we ignore activations for set definitions
opsWithMergedSets <- {
val (setOps, otherOps) = withoutCurrentActivations.partitionMap {
case (s: CreateSet, _, _) => Left(s)
case a => Right(a)
}
for {
mergedSets <- setOps.groupBy(_.target).values.biTraverse {
ops =>
for {
members <- ops.iterator
.flatMap(_.members)
.biFlatTraverse {
memberKey =>
matrix.get(memberKey) match {
case Some(value) if value.sizeIs == 1 =>
if (ac.allValid(value.head._2)) Right(List(memberKey)) else Right(Nil)
case Some(value) =>
Left(NEList(InconsistentSetMembers(memberKey, NEList.unsafeFrom(value.iterator.map(_._1.origin.value).toList))))
case None =>
reportMissingIfNotProvided(memberKey, key)(Right(List(memberKey)))
}
}.to(Set)
} yield {
(ops.head.copy(members = members), Set.empty[AxisPoint], Set.empty[AxisPoint])
}
}
} yield otherOps ++ mergedSets
}
_ <-
verifyStep(currentActivation, providedKeys, key, dependee, reportMissing, ops, opsWithMergedSets)
next <- checkConflicts(allAxis, opsWithMergedSets, execOpIndex, excludedActivations, effectType)
} yield {
allVisited.add((key, currentActivation))

val mutators =
justMutators.getOrElse(key, Set.empty).iterator.filter(ac `allValid` _._2).flatMap(m => queries.depsOf(execOpIndex, m._1)).toSeq

val goNext = next.iterator.map {
case (nextActivation, nextDeps) =>
() =>
go(
visited = visited + key,
current = (nextDeps ++ mutators).map(_ -> key),
currentActivation = currentActivation ++ nextActivation,
)
}

goNext
}
}
}
})

// for trampoline
sealed trait RecResult {
type RecursionResult <: Iterator[Either[NEList[PlanIssue], Iterator[() => RecursionResult]]]
}
type RecursionResult = RecResult#RecursionResult
@inline def RecursionResult(a: Iterator[Either[NEList[PlanIssue], Iterator[() => RecursionResult]]]): RecursionResult = a.asInstanceOf[RecursionResult]

// trampoline
val errors = Set.newBuilder[PlanIssue]
val remainder = mutable.Stack(() => go(Set.empty, Set.from(rootKeys.map(r => r -> r)), Set.empty))

while (remainder.nonEmpty) {
val i = remainder.pop().apply()
while (i.hasNext) {
i.next() match {
case Right(nextSteps) =>
remainder pushAll nextSteps
case Left(newErrors) =>
errors ++= newErrors
}
}
}

errors.result()
}

protected def verifyStep(
currentActivation: Set[AxisPoint],
providedKeys: DIKey => Boolean,
key: DIKey,
dependee: DIKey,
reportMissing: (DIKey, DIKey) => Left[NEList[MissingImport], Nothing],
ops: Set[(InstantiationOp, Set[AxisPoint])],
opsWithMergedSets: Set[(InstantiationOp, Set[AxisPoint], Set[AxisPoint])],
): Either[NEList[PlanIssue], Unit]

protected def checkConflicts(
allAxis: Map[String, Set[String]],
withoutCurrentActivations: Set[(InstantiationOp, Set[AxisPoint], Set[AxisPoint])],
execOpIndex: MutableMultiMap[DIKey, InstantiationOp],
excludedActivations: Set[NESet[AxisPoint]],
effectType: SafeType,
): Either[NEList[PlanIssue], Seq[(Set[AxisPoint], Set[DIKey])]]

}
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,90 @@ import izumi.distage.model.definition.BindingTag.AxisTag
import izumi.distage.model.definition.conflicts.{Annotated, Node}
import izumi.distage.model.definition.{Binding, ModuleBase}
import izumi.distage.model.plan.ExecutableOp.{CreateSet, InstantiationOp, MonadicOp, WiringOp}
import izumi.distage.model.plan.operations.OperationOrigin
import izumi.distage.model.plan.{ExecutableOp, Roots, Wiring}
import izumi.distage.model.planning.AxisPoint
import izumi.distage.model.reflection.DIKey
import izumi.distage.model.reflection.DIKey.SetElementKey
import izumi.distage.planning.solver.SemigraphSolver.SemiEdgeSeq
import izumi.distage.planning.{BindingTranslator, SubcontextHandler}
import izumi.functional.IzEither.*
import izumi.fundamentals.collections.MutableMultiMap
import izumi.fundamentals.collections.nonempty.NEList
import izumi.fundamentals.collections.{ImmutableMultiMap, MutableMultiMap}
import izumi.fundamentals.collections.nonempty.{NEList, NESet}
import izumi.fundamentals.graphs.WeakEdge
import izumi.fundamentals.graphs.struct.IncidenceMatrix
import izumi.fundamentals.graphs.tools.gc.Tracer

import scala.annotation.nowarn

@nowarn("msg=Unused import")
class GraphPreparations(
class GraphQueries(
bindingTranslator: BindingTranslator
) {

import scala.collection.compat.*
final def isIgnoredActivation(excludedActivations: Set[NESet[AxisPoint]], activation: Set[AxisPoint]): Boolean = {
excludedActivations.exists(_ subsetOf activation)
}

final def allImportingBindings(
matrix: ImmutableMultiMap[DIKey, (InstantiationOp, Set[AxisPoint])],
currentActivation: Set[AxisPoint],
)(importedKey: DIKey,
d: DIKey,
): Set[OperationOrigin] = {
// FIXME: reuse formatting from conflictingAxisTagsHint
matrix
.getOrElse(d, Set.empty)
.collect {
case (op, activations) if activations.subsetOf(currentActivation) && (op match {
case CreateSet(_, members, _) => members
case op: ExecutableOp.WiringOp => op.wiring.requiredKeys
case op: ExecutableOp.MonadicOp => Set(op.effectKey)
}).contains(importedKey) =>
op.origin.value
}
}

def nextDepsToVisit(
execOpIndex: MutableMultiMap[DIKey, InstantiationOp],
withoutCurrentActivations: Set[(InstantiationOp, Set[AxisPoint], Set[AxisPoint])],
): Right[Nothing, Seq[(Set[AxisPoint], Set[DIKey])]] = {
val next = withoutCurrentActivations.iterator.map {
case (op, activations, _) =>
// TODO: I'm not sure if it's "correct" to "activate" all the points together but it simplifies things greatly
val deps = depsOf(execOpIndex, op)
val acts = op match {
case _: CreateSet =>
Set.empty[AxisPoint]
case _ =>
activations
}
(acts, deps)
}.toSeq
Right(next)
}

final def depsOf(
execOpIndex: MutableMultiMap[DIKey, InstantiationOp],
op: InstantiationOp,
): Set[DIKey] = {
op match {
case cs: CreateSet =>
// we completely ignore weak members, they don't make any difference in case they are unreachable through other paths
val members = cs.members.filter {
case m: SetElementKey =>
getSetElementWeakEdges(execOpIndex, m).isEmpty
case _ =>
true
}
members
case op: ExecutableOp.WiringOp =>
toDep(op).deps
case op: ExecutableOp.MonadicOp =>
toDep(op).deps
}
}

def findWeakSetMembers(
setOps: Map[Annotated[DIKey], Node[DIKey, InstantiationOp]],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ object PlanSolver {

@nowarn("msg=Unused import")
class Impl(
resolver: SemigraphSolver[DIKey, Int, InstantiationOp],
preps: GraphPreparations,
resolver: SemigraphSolver[DIKey, Int, InstantiationOp],
preps: GraphQueries,
) extends PlanSolver {

import scala.collection.compat.*
Expand Down
Loading

0 comments on commit 4595e31

Please sign in to comment.