Skip to content

Commit

Permalink
Twirl enhancements (#952)
Browse files Browse the repository at this point in the history
This aims to bring the twirl module closer to parity with the twirl SBT plugin. It adds support for:

* Overriding full set of twirl template imports rather than just appending additional imports onto the default set
** Renames `twirlAdditionalImports` to `twirlImports`
** Sets the default value to `play.japi.twirl.compiler.TwirlCompiler.DEFAULT_IMPORTS`
* Overriding set of twirl extensions/formats that will be compiled
** Adds `def twirlFormats: T[Map[String, String]]` to define a mapping from extension to class name
** Sets the default value to
  ```scala
  Map(
    "html" -> "play.twirl.api.HtmlFormat",
    "xml" -> "play.twirl.api.XmlFormat",
    "js" -> "play.twirl.api.JavaScriptFormat",
    "txt" -> "play.twirl.api.TxtFormat"
  )
  ```

Commits:

* Allow overriding all twirl imports.

* Allow overriding twirl formats.

* Update twirl docs.

Pull request: #952
  • Loading branch information
mrdziuban authored Aug 31, 2020
1 parent 1a410f0 commit bbbf193
Show file tree
Hide file tree
Showing 7 changed files with 144 additions and 80 deletions.
26 changes: 11 additions & 15 deletions contrib/playlib/src/mill/playlib/Twirl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,17 @@ trait Twirl extends TwirlModule with Layout {

override def twirlSources=T.sources{ app() }

override def twirlAdditionalImports = Seq(
"_root_.play.twirl.api.TwirlFeatureImports._",
"_root_.play.twirl.api.TwirlHelperImports._",
"_root_.play.twirl.api.Html",
"_root_.play.twirl.api.JavaScript",
"_root_.play.twirl.api.Txt",
"_root_.play.twirl.api.Xml",
"models._",
"controllers._",
"play.api.i18n._",
"views.html._",
"play.api.templates.PlayMagic._",
"play.api.mvc._",
"play.api.data._"
)
override def twirlImports = T {
super.twirlImports() ++ Seq(
"models._",
"controllers._",
"play.api.i18n._",
"views.html._",
"play.api.templates.PlayMagic._",
"play.api.mvc._",
"play.api.data._"
)
}

def twirlOutput = T{Seq(compileTwirl().classes)}

Expand Down
9 changes: 7 additions & 2 deletions contrib/twirllib/src/TwirlModule.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@ trait TwirlModule extends mill.Module {
)
}

def twirlAdditionalImports: Seq[String] = Nil
def twirlImports: T[Seq[String]] = T {
TwirlWorkerApi.twirlWorker.defaultImports(twirlClasspath().map(_.path))
}

def twirlFormats: T[Map[String, String]] = TwirlWorkerApi.twirlWorker.defaultFormats

def twirlConstructorAnnotations: Seq[String] = Nil

Expand All @@ -47,7 +51,8 @@ trait TwirlModule extends mill.Module {
twirlClasspath().map(_.path),
twirlSources().map(_.path),
T.dest,
twirlAdditionalImports,
twirlImports(),
twirlFormats(),
twirlConstructorAnnotations,
twirlCodec,
twirlInclusiveDot)
Expand Down
73 changes: 44 additions & 29 deletions contrib/twirllib/src/TwirlWorker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,15 @@ import java.nio.charset.Charset
import mill.api.PathRef
import mill.scalalib.api.CompilationResult

import scala.jdk.CollectionConverters._
import scala.io.Codec
import scala.util.matching.Regex

class TwirlWorker {

private var twirlInstanceCache = Option.empty[(Long, TwirlWorkerApi)]
private var twirlInstanceCache = Option.empty[(Long, (TwirlWorkerApi, Class[_]))]

private def twirl(twirlClasspath: Agg[os.Path]) = {
private def twirlCompilerAndClass(twirlClasspath: Agg[os.Path]): (TwirlWorkerApi, Class[_]) = {
val classloaderSig = twirlClasspath.map(p => p.toString().hashCode + os.mtime(p)).sum
twirlInstanceCache match {
case Some((sig, instance)) if sig == classloaderSig => instance
Expand All @@ -23,7 +26,7 @@ class TwirlWorker {

// Switched to using the java api because of the hack-ish thing going on later.
//
// * we'll need to construct a collection of additional imports (will need to also include the defaults and add the user-provided additional imports)
// * we'll need to construct a collection of imports
// * we'll need to construct a collection of constructor annotations// *
// * the default collection in scala api is a Seq[String]
// * but it is defined in a different classloader (namely in cl)
Expand Down Expand Up @@ -57,26 +60,20 @@ class TwirlWorker {
cl.loadClass("scala.io.Codec"),
classOf[Boolean])

val defaultImportsMethod = twirlCompilerClass.getField("DEFAULT_IMPORTS")

val hashSetConstructor = hashSetClass.getConstructor(cl.loadClass("java.util.Collection"))

val instance = new TwirlWorkerApi {
override def compileTwirl(source: File,
sourceDirectory: File,
generatedDirectory: File,
formatterType: String,
additionalImports: Seq[String],
imports: Seq[String],
constructorAnnotations: Seq[String],
codec: Codec,
inclusiveDot: Boolean) {
// val defaultImports = play.japi.twirl.compiler.TwirlCompiler.DEFAULT_IMPORTS()
// val twirlAdditionalImports = new HashSet(defaultImports)
// additionalImports.foreach(twirlAdditionalImports.add)
val defaultImports = defaultImportsMethod.get(null) // unmodifiable collection
val twirlAdditionalImports = hashSetConstructor.newInstance(defaultImports).asInstanceOf[Object]
val hashSetAddMethod = twirlAdditionalImports.getClass.getMethod("add", classOf[Object])
additionalImports.foreach(hashSetAddMethod.invoke(twirlAdditionalImports, _))
// val twirlImports = new HashSet()
// imports.foreach(twirlImports.add)
val twirlImports = hashSetClass.newInstance().asInstanceOf[Object]
val hashSetAddMethod = twirlImports.getClass.getMethod("add", classOf[Object])
imports.foreach(hashSetAddMethod.invoke(twirlImports, _))

// Codec.apply(Charset.forName(codec.charSet.name()))
val twirlCodec = codecApplyMethod.invoke(null, charsetForNameMethod.invoke(null, codec.charSet.name()))
Expand All @@ -102,37 +99,56 @@ class TwirlWorker {
sourceDirectory,
generatedDirectory,
formatterType,
twirlAdditionalImports,
twirlImports,
twirlConstructorAnnotations,
twirlCodec,
Boolean.box(inclusiveDot)
)
}
}
twirlInstanceCache = Some((classloaderSig, instance))
instance
twirlInstanceCache = Some(classloaderSig -> (instance -> twirlCompilerClass))
(instance, twirlCompilerClass)
}
}

private def twirl(twirlClasspath: Agg[os.Path]): TwirlWorkerApi =
twirlCompilerAndClass(twirlClasspath)._1

private def twirlClass(twirlClasspath: Agg[os.Path]): Class[_] =
twirlCompilerAndClass(twirlClasspath)._2

def defaultImports(twirlClasspath: Agg[os.Path]): Seq[String] =
twirlClass(twirlClasspath).getField("DEFAULT_IMPORTS")
.get(null).asInstanceOf[java.util.Set[String]].asScala.toSeq

def defaultFormats: Map[String, String] =
Map(
"html" -> "play.twirl.api.HtmlFormat",
"xml" -> "play.twirl.api.XmlFormat",
"js" -> "play.twirl.api.JavaScriptFormat",
"txt" -> "play.twirl.api.TxtFormat")

def compile(twirlClasspath: Agg[os.Path],
sourceDirectories: Seq[os.Path],
dest: os.Path,
additionalImports: Seq[String],
imports: Seq[String],
formats: Map[String, String],
constructorAnnotations: Seq[String],
codec: Codec,
inclusiveDot: Boolean)
(implicit ctx: mill.api.Ctx): mill.api.Result[CompilationResult] = {
val compiler = twirl(twirlClasspath)
val formatExtsRegex = formats.keys.map(Regex.quote).mkString("|")

def compileTwirlDir(inputDir: os.Path) {
os.walk(inputDir).filter(_.last.matches(".*.scala.(html|xml|js|txt)"))
os.walk(inputDir).filter(_.last.matches(s".*.scala.($formatExtsRegex)"))
.foreach { template =>
val extFormat = twirlExtensionFormat(template.last)
val extClass = twirlExtensionClass(template.last, formats)
compiler.compileTwirl(template.toIO,
inputDir.toIO,
dest.toIO,
s"play.twirl.api.$extFormat",
additionalImports,
extClass,
imports,
constructorAnnotations,
codec,
inclusiveDot
Expand All @@ -148,19 +164,18 @@ class TwirlWorker {
mill.api.Result.Success(CompilationResult(zincFile, PathRef(classesDir)))
}

private def twirlExtensionFormat(name: String) =
if (name.endsWith("html")) "HtmlFormat"
else if (name.endsWith("xml")) "XmlFormat"
else if (name.endsWith("js")) "JavaScriptFormat"
else "TxtFormat"
private def twirlExtensionClass(name: String, formats: Map[String, String]) =
formats.collectFirst { case (ext, klass) if name.endsWith(ext) => klass }.getOrElse {
throw new IllegalStateException(s"Unknown twirl extension for file: $name. Known extensions: ${formats.keys.mkString(", ")}")
}
}

trait TwirlWorkerApi {
def compileTwirl(source: File,
sourceDirectory: File,
generatedDirectory: File,
formatterType: String,
additionalImports: Seq[String],
imports: Seq[String],
constructorAnnotations: Seq[String],
codec: Codec,
inclusiveDot: Boolean)
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
17 changes: 10 additions & 7 deletions contrib/twirllib/test/src/HelloWorldTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ object HelloWorldTests extends TestSuite {
object HelloWorld extends HelloBase {

object core extends HelloWorldModule {
override def twirlAdditionalImports: Seq[String] = testAdditionalImports
override def twirlImports = super.twirlImports() ++ testAdditionalImports
override def twirlFormats = super.twirlFormats() ++ Map("svg" -> "play.twirl.api.HtmlFormat")
override def twirlConstructorAnnotations: Seq[String] = testConstructorAnnotations
}

Expand All @@ -31,6 +32,7 @@ object HelloWorldTests extends TestSuite {

object core extends HelloWorldModule {
override def twirlInclusiveDot: Boolean = true
override def twirlFormats = super.twirlFormats() ++ Map("svg" -> "play.twirl.api.HtmlFormat")
}

}
Expand All @@ -51,8 +53,9 @@ object HelloWorldTests extends TestSuite {
}

def compileClassfiles: Seq[os.RelPath] = Seq[os.RelPath](
os.rel / "hello.template.scala",
os.rel / "wrapper.template.scala"
os.rel / 'html / "hello.template.scala",
os.rel / 'html / "wrapper.template.scala",
os.rel / 'svg / "test.template.scala"
)

def expectedDefaultImports: Seq[String] = Seq(
Expand Down Expand Up @@ -92,14 +95,14 @@ object HelloWorldTests extends TestSuite {

val outputFiles = os.walk(result.classes.path).filter(_.last.endsWith(".scala"))
val expectedClassfiles = compileClassfiles.map(
eval.outPath / 'core / 'compileTwirl / 'dest / 'html / _
eval.outPath / 'core / 'compileTwirl / 'dest / _
)

assert(
result.classes.path == eval.outPath / 'core / 'compileTwirl / 'dest,
outputFiles.nonEmpty,
outputFiles.forall(expectedClassfiles.contains),
outputFiles.size == 2,
outputFiles.size == 3,
evalCount > 0,
outputFiles.forall { p =>
val lines = os.read.lines(p).map(_.trim)
Expand All @@ -124,7 +127,7 @@ object HelloWorldTests extends TestSuite {

val outputFiles = os.walk(result.classes.path).filter(_.last.endsWith(".scala"))
val expectedClassfiles = compileClassfiles.map( name =>
eval.outPath / 'core / 'compileTwirl / 'dest / 'html / name.toString().replace(".template.scala", "$$TwirlInclusiveDot.template.scala")
eval.outPath / 'core / 'compileTwirl / 'dest / name / os.RelPath.up / name.last.replace(".template.scala", "$$TwirlInclusiveDot.template.scala")
)

println(s"outputFiles: $outputFiles")
Expand All @@ -133,7 +136,7 @@ object HelloWorldTests extends TestSuite {
result.classes.path == eval.outPath / 'core / 'compileTwirl / 'dest,
outputFiles.nonEmpty,
outputFiles.forall(expectedClassfiles.contains),
outputFiles.size == 2,
outputFiles.size == 3,
evalCount > 0,
outputFiles.filter(_.toString().contains("hello.template.scala")).forall { p =>
val lines = os.read.lines(p).map(_.trim)
Expand Down
Loading

0 comments on commit bbbf193

Please sign in to comment.