Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Detect assemblies with too many entries to fail shell script prepending #3140

Merged
merged 9 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion example/basic/4-builtin-commands/build.sc
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,8 @@ foo.sources
foo.allSources
foo.allSourceFiles
foo.compile
foo.localClasspath
foo.finalMainClassOpt
foo.prependShellScript
foo.assembly

*/
Expand Down
85 changes: 62 additions & 23 deletions scalalib/src/mill/scalalib/Assembly.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@ import scala.jdk.CollectionConverters._
import scala.tools.nsc.io.Streamable
import scala.util.Using

case class Assembly(pathRef: PathRef, addedEntries: Int)

object Assembly {

implicit val assemblyJsonRW: upickle.default.ReadWriter[Assembly] = upickle.default.macroRW

val defaultRules: Seq[Rule] = Seq(
Rule.Append("reference.conf", separator = "\n"),
Rule.Exclude(JarFile.MANIFEST_NAME),
Expand Down Expand Up @@ -195,13 +199,38 @@ object Assembly {
base: Option[os.Path] = None,
assemblyRules: Seq[Assembly.Rule] = Assembly.defaultRules
)(implicit ctx: Ctx.Dest with Ctx.Log): PathRef = {
val tmp = ctx.dest / "out-tmp.jar"
create(
destJar = ctx.dest / "out.jar",
inputPaths = inputPaths,
manifest = manifest,
prependShellScript = Option(prependShellScript).filter(_ != ""),
base = base,
assemblyRules = assemblyRules
).pathRef
}

def create(
destJar: os.Path,
inputPaths: Agg[os.Path],
manifest: mill.api.JarManifest = mill.api.JarManifest.MillDefault,
prependShellScript: Option[String] = None,
base: Option[os.Path] = None,
assemblyRules: Seq[Assembly.Rule] = Assembly.defaultRules
): Assembly = {
val rawJar = os.temp("out-tmp", deleteOnExit = false)
// we create the file later
os.remove(rawJar)

val baseUri = "jar:" + tmp.toIO.getCanonicalFile.toURI.toASCIIString
// use the `base` (the upstream assembly) as a start
val baseUri = "jar:" + rawJar.toIO.getCanonicalFile.toURI.toASCIIString
val hm = base.fold(Map("create" -> "true")) { b =>
os.copy(b, tmp)
os.copy(b, rawJar)
Map.empty
}

var addedEntryCount = 0

// Add more files by copying files to a JAR file system
Using.resource(FileSystems.newFileSystem(URI.create(baseUri), hm.asJava)) { zipFs =>
val manifestPath = zipFs.getPath(JarFile.MANIFEST_NAME)
Files.createDirectories(manifestPath.getParent)
Expand All @@ -225,37 +254,47 @@ object Assembly {
Seq(new ByteArrayInputStream(entry.separator.getBytes), inputStream())
)
val cleaned = if (Files.exists(path)) separated else separated.drop(1)
val concatenated =
new SequenceInputStream(Collections.enumeration(cleaned.asJava))
val concatenated = new SequenceInputStream(Collections.enumeration(cleaned.asJava))
addedEntryCount += 1
writeEntry(path, concatenated, append = true)
case entry: WriteOnceEntry => writeEntry(path, entry.inputStream(), append = false)
case entry: WriteOnceEntry =>
addedEntryCount += 1
writeEntry(path, entry.inputStream(), append = false)
}
}
} finally {
resourceCleaner()
}
}

val output = ctx.dest / "out.jar"
// Prepend shell script and make it executable
if (prependShellScript.isEmpty) os.move(tmp, output)
else {
val lineSep = if (!prependShellScript.endsWith("\n")) "\n\r\n" else ""
os.write(output, prependShellScript + lineSep)
os.write.append(output, os.read.inputStream(tmp))

if (!scala.util.Properties.isWin) {
os.perms.set(
output,
os.perms(output)
+ PosixFilePermission.GROUP_EXECUTE
+ PosixFilePermission.OWNER_EXECUTE
+ PosixFilePermission.OTHERS_EXECUTE
)
}
prependShellScript match {
case None =>
os.move(rawJar, destJar)
case Some(prependShellScript) =>
val lineSep = if (!prependShellScript.endsWith("\n")) "\n\r\n" else ""
val prepend = prependShellScript + lineSep
// Write the prepend-part into the final jar file
// https://en.wikipedia.org/wiki/Zip_(file_format)#Combination_with_other_file_formats
os.write(destJar, prepend)
// Append the actual JAR content
Using.resource(os.read.inputStream(rawJar)) { is =>
os.write.append(destJar, is)
}
os.remove(rawJar)

if (!scala.util.Properties.isWin) {
os.perms.set(
destJar,
os.perms(destJar)
+ PosixFilePermission.GROUP_EXECUTE
+ PosixFilePermission.OWNER_EXECUTE
+ PosixFilePermission.OTHERS_EXECUTE
)
}
}

PathRef(output)
Assembly(PathRef(destJar), addedEntryCount)
}

private def writeEntry(p: java.nio.file.Path, inputStream: InputStream, append: Boolean): Unit = {
Expand Down
66 changes: 60 additions & 6 deletions scalalib/src/mill/scalalib/JavaModule.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import mill.scalalib.publish.Artifact
import mill.util.Jvm
import os.{Path, ProcessOutput}

import scala.annotation.nowarn

/**
* Core configuration required to compile a single Java compilation target
*/
Expand Down Expand Up @@ -569,11 +571,31 @@ trait JavaModule
*
* This should allow much faster assembly creation in the common case where
* upstream dependencies do not change
*
* This implementation is deprecated because of it's return value.
* Please use [[upstreamAssembly2]] instead.
*/
@deprecated("Use upstreamAssembly2 instead, which has a richer return value", "Mill 0.11.8")
def upstreamAssembly: T[PathRef] = T {
Assembly.createAssembly(
upstreamAssemblyClasspath().map(_.path),
manifest(),
T.log.error(
s"upstreamAssembly target is deprecated and should no longer used." +
s" Please make sure to use upstreamAssembly2 instead."
)
upstreamAssembly2().pathRef
}

/**
* Build the assembly for upstream dependencies separate from the current
* classpath
*
* This should allow much faster assembly creation in the common case where
* upstream dependencies do not change
*/
def upstreamAssembly2: T[Assembly] = T {
Assembly.create(
destJar = T.dest / "out.jar",
inputPaths = upstreamAssemblyClasspath().map(_.path),
manifest = manifest(),
assemblyRules = assemblyRules
)
}
Expand All @@ -583,13 +605,45 @@ trait JavaModule
* classfiles from this module and all it's upstream modules and dependencies
*/
def assembly: T[PathRef] = T {
Assembly.createAssembly(
// detect potential inconsistencies due to `upstreamAssembly` deprecation after 0.11.7
if (
(upstreamAssembly.ctx.enclosing: @nowarn) != s"${classOf[JavaModule].getName}#upstreamAssembly"
) {
T.log.error(
s"${upstreamAssembly.ctx.enclosing: @nowarn} is overriding a deprecated target which is no longer used." +
s" Please make sure to override upstreamAssembly2 instead."
)
}

val prependScript = Option(prependShellScript()).filter(_ != "")
val upstream = upstreamAssembly2()

val created = Assembly.create(
destJar = T.dest / "out.jar",
Agg.from(localClasspath().map(_.path)),
manifest(),
prependShellScript(),
Some(upstreamAssembly().path),
prependScript,
Some(upstream.pathRef.path),
assemblyRules
)
// See https://github.com/com-lihaoyi/mill/pull/2655#issuecomment-1672468284
val problematicEntryCount = 65535
if (
prependScript.isDefined &&
(upstream.addedEntries + created.addedEntries) > problematicEntryCount
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if it's really worth keeping track of the number of entries incrementally? v.s. creating the assembly first, scanning it, and then failing if it has too many entries. If the performance of scanning the assembly is acceptable (milliseconds to tens of milliseconds?) then that would save us a bunch of book-keeping passing around addedEntries values, and a bunch of churn in replacing upstreamAssembly with upstreamAssembly2

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It probably depends. When we scan the jar right after it's written, it should be in the file system cache of any reasonable OS and scanning should be fast. On the other side, just remembering the added count is quasi for free, esp. when we assume, that the upstream-assembly is the larger portion of the assembly and keeps stable.

I think I want to experiment what the result looks like and how it performs.

Copy link
Member Author

@lefou lefou May 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The issues with scanning afterwards:

  • Scanning and counting entries is slow. Scanning the jar of test case noExe.large took almost 3 seconds (75516 entries), although we just wrote it
  • Java's ZipInputStream and JarInputStream aren't able to find any ZipEntry in a prependend jar making scanning after packaging non-trivial
  • Shelling out to some OS-installed zip tools doesn't seem right

) {
Result.Failure(
s"""The created assembly jar contains more than ${problematicEntryCount} ZIP entries.
|JARs of that size are known to not work correctly with a prepended shell script.
|Either reduce the entries count of the assembly or disable the prepended shell script with:
|
| def prependShellScript = ""
|""".stripMargin,
Some(created.pathRef)
)
} else {
Result.Success(created.pathRef)
}
}

/**
Expand Down
147 changes: 147 additions & 0 deletions scalalib/test/src/mill/scalalib/AssemblyTests.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package mill.scalalib

import mill._
import mill.api.Result
import mill.eval.Evaluator
import mill.util.{Jvm, TestEvaluator, TestUtil}
import utest._
import utest.framework.TestPath

import java.io.PrintStream

// Ensure the assembly is runnable, even if we have assembled lots of dependencies into it
// Reproduction of issues:
// - https://github.com/com-lihaoyi/mill/issues/528
// - https://github.com/com-lihaoyi/mill/issues/2650

object AssemblyTests extends TestSuite {

object TestCase extends TestUtil.BaseModule {
trait Setup extends ScalaModule {
def scalaVersion = "2.13.11"
def sources = T.sources(T.workspace / "src")
def ivyDeps = super.ivyDeps() ++ Agg(
ivy"com.lihaoyi::scalatags:0.8.2",
ivy"com.lihaoyi::mainargs:0.4.0",
ivy"org.apache.avro:avro:1.11.1"
)
}
trait ExtraDeps extends ScalaModule {
def ivyDeps = super.ivyDeps() ++ Agg(
ivy"dev.zio::zio:2.0.15",
ivy"org.typelevel::cats-core:2.9.0",
ivy"org.apache.spark::spark-core:3.4.0",
ivy"dev.zio::zio-metrics-connectors:2.0.8",
ivy"dev.zio::zio-http:3.0.0-RC2"
)
}

object noExe extends Module {
object small extends Setup {
override def prependShellScript: T[String] = ""
}
object large extends Setup with ExtraDeps {
override def prependShellScript: T[String] = ""
}
}

object exe extends Module {
object small extends Setup
object large extends Setup with ExtraDeps
}

}

val sources = Map(
os.rel / "src" / "Main.scala" ->
"""package ultra
|
|import scalatags.Text.all._
|import mainargs.{main, ParserForMethods}
|
|object Main {
| def generateHtml(text: String) = {
| h1(text).toString
| }
|
| @main
| def main(text: String) = {
| println(generateHtml(text))
| }
|
| def main(args: Array[String]): Unit = ParserForMethods(this).runOrExit(args)
|}""".stripMargin
)

def workspaceTest[T](
m: TestUtil.BaseModule,
env: Map[String, String] = Evaluator.defaultEnv,
debug: Boolean = false,
errStream: PrintStream = System.err
)(t: TestEvaluator => T)(implicit tp: TestPath): T = {
val eval = new TestEvaluator(m, env = env, debugEnabled = debug, errStream = errStream)
os.remove.all(m.millSourcePath)
sources.foreach { case (file, content) =>
os.write(m.millSourcePath / file, content, createFolders = true)
}
os.remove.all(eval.outPath)
os.makeDir.all(m.millSourcePath / os.up)
t(eval)
}

def runAssembly(file: os.Path, wd: os.Path, checkExe: Boolean = false): Unit = {
println(s"File size: ${os.stat(file).size}")
Jvm.runSubprocess(
commandArgs = Seq(Jvm.javaExe, "-jar", file.toString(), "--text", "tutu"),
envArgs = Map.empty[String, String],
workingDir = wd
)
if (checkExe) {
Jvm.runSubprocess(
commandArgs = Seq(file.toString(), "--text", "tutu"),
envArgs = Map.empty[String, String],
workingDir = wd
)
}
}

def tests: Tests = Tests {
test("Assembly") {
test("noExe") {
test("small") {
workspaceTest(TestCase) { eval =>
val Right((res, _)) = eval(TestCase.noExe.small.assembly)
runAssembly(res.path, TestCase.millSourcePath)
}
}
test("large") {
workspaceTest(TestCase) { eval =>
val Right((res, _)) = eval(TestCase.noExe.large.assembly)
runAssembly(res.path, TestCase.millSourcePath)
}
}
}
test("exe") {
test("small") {
workspaceTest(TestCase) { eval =>
val Right((res, _)) = eval(TestCase.exe.small.assembly)
runAssembly(res.path, TestCase.millSourcePath, checkExe = true)
}
}
test("large-should-fail") {
workspaceTest(TestCase) { eval =>
val Left(Result.Failure(msg, Some(res))) = eval(TestCase.exe.large.assembly)
val expectedMsg =
"""The created assembly jar contains more than 65535 ZIP entries.
|JARs of that size are known to not work correctly with a prepended shell script.
|Either reduce the entries count of the assembly or disable the prepended shell script with:
|
| def prependShellScript = ""
|""".stripMargin
assert(msg == expectedMsg)
}
}
}
}
}
}