Skip to content

Commit

Permalink
Clean workers too from the clean command (#3579)
Browse files Browse the repository at this point in the history
This PR ensures the `clean` command also works for worker tasks, and allows
users to dispose of a worker by removing its corresponding metadata
file on disk (`out/foo/theWorker.json` for worker `foo.theWorker`). In
that case, the cached worker instance is discarded the next time it is accessed.

Fixes #3276

---------

Co-authored-by: Li Haoyi <haoyi.sg@gmail.com>
  • Loading branch information
alexarchambault and lihaoyi authored Sep 26, 2024
1 parent 69243dc commit a1eb82c
Show file tree
Hide file tree
Showing 5 changed files with 195 additions and 7 deletions.
3 changes: 3 additions & 0 deletions main/define/src/mill/define/Segments.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ case class Segments private (value: Seq[Segment]) {
def ++(other: Seq[Segment]): Segments = Segments(value ++ other)
def ++(other: Segments): Segments = Segments(value ++ other.value)

/**
 * Tells whether this segment path begins with all the segments of `prefix`,
 * in the same order.
 *
 * @param prefix the candidate leading segments
 * @return `true` if the underlying segment sequence starts with `prefix.value`
 */
def startsWith(prefix: Segments): Boolean = {
  val leading = prefix.value
  value.startsWith(leading)
}

def parts: List[String] = value.toList match {
case Nil => Nil
case Segment.Label(head) :: rest =>
Expand Down
7 changes: 7 additions & 0 deletions main/eval/src/mill/eval/Evaluator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,14 @@ trait Evaluator {
def outPath: os.Path
def externalOutPath: os.Path
def pathsResolver: EvaluatorPathsResolver
// Cache of worker values, keyed by the Segments of the worker task.
// NOTE(review): the Int in each entry looks like the hash that GroupEvaluator
// compares against workerCacheHash(inputsHash) — confirm against the caller.
// TODO In 0.13.0, workerCache should have the type of mutableWorkerCache,
// while the latter should be removed
def workerCache: collection.Map[Segments, (Int, Val)]
/**
 * Mutable view of [[workerCache]], for internal callers that need to evict entries.
 *
 * `workerCache` is declared with the read-only `collection.Map` type, so this
 * accessor narrows it at runtime and fails fast when an implementation supplied
 * a map that is not mutable.
 */
private[mill] final def mutableWorkerCache: collection.mutable.Map[Segments, (Int, Val)] = {
  val declaredCache = workerCache
  declaredCache match {
    case mutableCache: collection.mutable.Map[Segments, (Int, Val)] =>
      mutableCache
    case _ =>
      sys.error("Evaluator#workerCache must be a mutable map")
  }
}
def disableCallgraphInvalidation: Boolean = false

@deprecated(
Expand Down
14 changes: 11 additions & 3 deletions main/eval/src/mill/eval/GroupEvaluator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,14 @@ private[mill] trait GroupEvaluator {

val cached = loadCachedJson(logger, inputsHash, labelled, paths)

val upToDateWorker = loadUpToDateWorker(logger, inputsHash, labelled)
val upToDateWorker = loadUpToDateWorker(
logger,
inputsHash,
labelled,
forceDiscard =
// worker metadata file removed by user, let's recompute the worker
cached.isEmpty
)

upToDateWorker.map((_, inputsHash)) orElse cached.flatMap(_._2) match {
case Some((v, hashCode)) =>
Expand Down Expand Up @@ -444,7 +451,8 @@ private[mill] trait GroupEvaluator {
private def loadUpToDateWorker(
logger: ColorLogger,
inputsHash: Int,
labelled: Terminal.Labelled[_]
labelled: Terminal.Labelled[_],
forceDiscard: Boolean
): Option[Val] = {
labelled.task.asWorker
.flatMap { w =>
Expand All @@ -454,7 +462,7 @@ private[mill] trait GroupEvaluator {
}
.flatMap {
case (cachedHash, upToDate)
if cachedHash == workerCacheHash(inputsHash) =>
if cachedHash == workerCacheHash(inputsHash) && !forceDiscard =>
Some(upToDate) // worker cached and up-to-date

case (_, Val(obsolete: AutoCloseable)) =>
Expand Down
17 changes: 13 additions & 4 deletions main/src/mill/main/MainModule.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package mill.main

import java.util.concurrent.LinkedBlockingQueue
import mill.define.{BaseModule0, Command, NamedTask, Segments, Target, Task}
import mill.api.{Ctx, Logger, PathRef, Result}
import mill.api.{Ctx, Logger, PathRef, Result, Val}
import mill.eval.{Evaluator, EvaluatorPaths, Terminal}
import mill.resolve.{Resolve, SelectMode}
import mill.resolve.SelectMode.Separated
Expand Down Expand Up @@ -328,14 +328,14 @@ trait MainModule extends BaseModule0 {

val pathsToRemove =
if (targets.isEmpty)
Right(os.list(rootDir).filterNot(keepPath))
Right((os.list(rootDir).filterNot(keepPath), List(mill.define.Segments())))
else
mill.resolve.Resolve.Segments.resolve(
evaluator.rootModule,
targets,
SelectMode.Multi
).map { ts =>
ts.flatMap { segments =>
val allPaths = ts.flatMap { segments =>
val evPaths = EvaluatorPaths.resolveDestPaths(rootDir, segments)
val paths = Seq(evPaths.dest, evPaths.meta, evPaths.log)
val potentialModulePath = rootDir / EvaluatorPaths.makeSegmentStrings(segments)
Expand All @@ -348,12 +348,21 @@ trait MainModule extends BaseModule0 {
paths :+ potentialModulePath
} else paths
}
(allPaths, ts)
}

pathsToRemove match {
case Left(err) =>
Result.Failure(err)
case Right(paths) =>
case Right((paths, allSegments)) =>
for {
workerSegments <- evaluator.workerCache.keys.toList
if allSegments.exists(workerSegments.startsWith)
(_, Val(closeable: AutoCloseable)) <- evaluator.mutableWorkerCache.remove(workerSegments)
} {
closeable.close()
}

val existing = paths.filter(p => os.exists(p))
Target.log.debug(s"Cleaning ${existing.size} paths ...")
existing.foreach(os.remove.all)
Expand Down
161 changes: 161 additions & 0 deletions main/test/src/mill/main/MainModuleTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@ package mill.main
import mill.api.{PathRef, Result, Val}
import mill.{Agg, T, Task}
import mill.define.{Cross, Discover, Module}
import mill.main.client.OutFiles
import mill.testkit.UnitTester
import mill.testkit.TestBaseModule
import utest.{TestSuite, Tests, assert, test}

import java.io.{ByteArrayOutputStream, PrintStream}

import scala.collection.mutable

object MainModuleTests extends TestSuite {

object mainModule extends TestBaseModule with MainModule {
Expand Down Expand Up @@ -72,6 +75,55 @@ object MainModuleTests extends TestSuite {
}
}

/**
 * Closeable test double for a worker.
 *
 * Each instance adds itself to the shared `workers` registry when constructed
 * and removes itself the first time `close()` is called, so tests can count the
 * instances that are still alive.
 *
 * @param name    label shown in `toString`
 * @param workers registry of live instances; every access is guarded by
 *                synchronizing on the set itself
 */
class TestWorker(val name: String, workers: mutable.HashSet[TestWorker]) extends AutoCloseable {

  // Register as soon as the instance exists.
  workers.synchronized {
    workers += this
  }

  /** Becomes `true` once `close()` has run. */
  var closed = false

  /** Deregisters this worker; further calls are no-ops. */
  def close(): Unit = closed match {
    case true => ()
    case false =>
      workers.synchronized {
        workers -= this
      }
      closed = true
  }

  override def toString(): String = {
    // The identity hash tells apart two workers that share a name.
    val identity = Integer.toHexString(System.identityHashCode(this))
    s"TestWorker($name)@$identity"
  }
}

/**
 * Test build whose modules each expose a worker task creating a [[TestWorker]].
 *
 * @param workers registry handed to every created `TestWorker`, so tests can
 *                observe how many worker instances are currently alive
 */
class WorkerModule(workers: mutable.HashSet[TestWorker]) extends TestBaseModule with MainModule {

  /** Mixin that gives a module a `theWorker` task creating a "shared" worker. */
  trait Cleanable extends Module {
    def theWorker = Task.Worker {
      new TestWorker("shared", workers)
    }
  }

  object foo extends Cleanable {
    // Nested module with its own worker; `all` below does not depend on it.
    object sub extends Cleanable
  }
  object bar extends Cleanable {
    // `override` is required: `Cleanable.theWorker` is a concrete member, and
    // Scala rejects redefining a concrete member without the modifier.
    override def theWorker = Task.Worker {
      new TestWorker("bar", workers)
    }
  }
  object bazz extends Cross[Bazz]("1", "2", "3")
  trait Bazz extends Cleanable with Cross.Module[String]

  /**
   * Depends on five workers (`foo`, `bar` and the three `bazz` cross
   * instances), so evaluating it instantiates all of them.
   */
  def all = Task {
    foo.theWorker()
    bar.theWorker()
    bazz("1").theWorker()
    bazz("2").theWorker()
    bazz("3").theWorker()

    ()
  }
}

override def tests: Tests = Tests {

test("inspect") {
Expand Down Expand Up @@ -317,5 +369,114 @@ object MainModuleTests extends TestSuite {
)
}
}

test("cleanWorker") {
test("all") {
  // Every TestWorker joins this set when created and leaves it on close().
  val liveWorkers = new mutable.HashSet[TestWorker]
  val build = new WorkerModule(liveWorkers)
  val tester = UnitTester(build, null)

  // Evaluating `all` instantiates the five workers it depends on.
  val built = tester.evaluator.evaluate(Agg(build.all))
  assert(built.failing.keyCount == 0)
  assert(liveWorkers.size == 5)

  // `clean` without targets must close every cached worker.
  val cleaned = tester.evaluator.evaluate(Agg(build.clean(tester.evaluator)))
  assert(cleaned.failing.keyCount == 0)
  assert(liveWorkers.isEmpty)
}

test("single-target") {
  // Every TestWorker joins this set when created and leaves it on close().
  val liveWorkers = new mutable.HashSet[TestWorker]
  val build = new WorkerModule(liveWorkers)
  val tester = UnitTester(build, null)

  val built = tester.evaluator.evaluate(Agg(build.all))
  assert(built.failing.keyCount == 0)
  assert(liveWorkers.size == 5)

  // Cleaning one worker task at a time closes exactly that worker:
  // (target selector, live workers expected afterwards).
  val steps = Seq(
    "foo.theWorker" -> 4,
    "bar.theWorker" -> 3,
    "bazz[1].theWorker" -> 2
  )
  steps.foreach { case (target, expectedLive) =>
    val cleaned = tester.evaluator.evaluate(Agg(build.clean(tester.evaluator, target)))
    assert(cleaned.failing.keyCount == 0)
    assert(liveWorkers.size == expectedLive)
  }
}

// Checks that deleting a worker's metadata file (`<out>/<module>/theWorker.json`)
// makes the next evaluation close the cached worker and create a fresh one.
test("single-target via rm") {
val workers = new mutable.HashSet[TestWorker]
val workerModule = new WorkerModule(workers)
val ev = UnitTester(workerModule, null)

// First evaluation of foo.theWorker creates exactly one worker.
ev.evaluator.evaluate(Agg(workerModule.foo.theWorker))
.ensuring(_.failing.keyCount == 0)
assert(workers.size == 1)

val originalFooWorker = workers.head

// Creating bar's worker must leave foo's worker untouched.
ev.evaluator.evaluate(Agg(workerModule.bar.theWorker))
.ensuring(_.failing.keyCount == 0)
assert(workers.size == 2)
assert(workers.exists(_ eq originalFooWorker))

// The only other registered instance is bar's worker.
val originalBarWorker = workers.filter(_ ne originalFooWorker).head

// Re-evaluating without touching the metadata reuses the same instances
// (checked by reference equality).
ev.evaluator.evaluate(Agg(workerModule.foo.theWorker))
.ensuring(_.failing.keyCount == 0)
assert(workers.size == 2)
assert(workers.exists(_ eq originalFooWorker))

ev.evaluator.evaluate(Agg(workerModule.bar.theWorker))
.ensuring(_.failing.keyCount == 0)
assert(workers.size == 2)
assert(workers.exists(_ eq originalBarWorker))

val outDir = os.Path(OutFiles.out, workerModule.millSourcePath)

// Remove foo's metadata file: the next evaluation must close the old worker
// and register a new one, keeping the live count at 2.
assert(!originalFooWorker.closed)
os.remove(outDir / "foo/theWorker.json")

ev.evaluator.evaluate(Agg(workerModule.foo.theWorker))
.ensuring(_.failing.keyCount == 0)
assert(workers.size == 2)
assert(!workers.exists(_ eq originalFooWorker))
assert(originalFooWorker.closed)

// Same check for bar, whose worker was still open until now.
assert(!originalBarWorker.closed)
os.remove(outDir / "bar/theWorker.json")

ev.evaluator.evaluate(Agg(workerModule.bar.theWorker))
.ensuring(_.failing.keyCount == 0)
assert(workers.size == 2)
assert(!workers.exists(_ eq originalBarWorker))
assert(originalBarWorker.closed)
}

test("single-module") {
  // Every TestWorker joins this set when created and leaves it on close().
  val liveWorkers = new mutable.HashSet[TestWorker]
  val build = new WorkerModule(liveWorkers)
  val tester = UnitTester(build, null)

  val built = tester.evaluator.evaluate(Agg(build.all))
  assert(built.failing.keyCount == 0)
  assert(liveWorkers.size == 5)

  // Cleaning a whole module closes the worker defined under it:
  // (module selector, live workers expected afterwards).
  val steps = Seq(
    "foo" -> 4,
    "bar" -> 3,
    "bazz[1]" -> 2
  )
  steps.foreach { case (moduleSelector, expectedLive) =>
    val cleaned =
      tester.evaluator.evaluate(Agg(build.clean(tester.evaluator, moduleSelector)))
    assert(cleaned.failing.keyCount == 0)
    assert(liveWorkers.size == expectedLive)
  }
}
}
}
}

0 comments on commit a1eb82c

Please sign in to comment.