Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added 3.4.2 support in a fairly formulaic fashion #1503

Merged
merged 1 commit into from
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions .github/workflows/actions.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
fail-fast: false
matrix:
java-version: [ 8, 17 ]
scala-binary-version: [ 2.12, 2.13, 3 ]
scala-binary-version: [ 2.12, 2.13, 3.3, 3.4.2 ]
include:
# Including lowest/highest versions that work with Java 21
- java-version: 21
Expand All @@ -33,6 +33,8 @@ jobs:
scala-version: 3.3.1
- java-version: 21
scala-version: 3.3.3
- java-version: 21
scala-version: 3.4.2
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
Expand All @@ -49,7 +51,7 @@ jobs:
fail-fast: false
matrix:
java-version: [ 8, 17 ]
scala-version: [ 2.12, 2.13, 3 ]
scala-version: [ 2.12, 2.13, 3.3, 3.4.2 ]
include:
# Including lowest/highest versions that work with Java 21
- java-version: 21
Expand All @@ -64,6 +66,8 @@ jobs:
scala-version: 3.3.1
- java-version: 21
scala-version: 3.3.3
- java-version: 21
scala-version: 3.4.2

runs-on: ubuntu-latest
steps:
Expand Down Expand Up @@ -95,7 +99,7 @@ jobs:
strategy:
fail-fast: false
matrix:
scala-version: [ 2.12.19, 2.13.14, 3.3.3 ]
scala-version: [ 2.12.19, 2.13.14, 3.3.3, 3.4.2 ]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package dotty.ammonite.compiler

import java.io.{File => JFile}
import dotty.tools.dotc.classpath
import dotty.tools.io.{AbstractFile, PlainFile, ClassPath, ClassRepresentation, EfficientClassPath}
import classpath.FileUtils._

case class DirectoryClassPath(dir: JFile)
extends classpath.JFileDirectoryLookup[classpath.ClassFileEntryImpl]
with classpath.NoSourcePaths {
override def findClass(className: String): Option[ClassRepresentation] =
findClassFile(className) map classpath.ClassFileEntryImpl

def findClassFile(className: String): Option[AbstractFile] = {
val relativePath = classpath.FileUtils.dirPath(className)
val classFile = new JFile(dir, relativePath + ".class")
if (classFile.exists) {
val wrappedClassFile = new dotty.tools.io.File(classFile.toPath)
val abstractClassFile = new PlainFile(wrappedClassFile)
Some(abstractClassFile)
} else None
}

protected def createFileEntry(file: AbstractFile): classpath.ClassFileEntryImpl =
classpath.ClassFileEntryImpl(file)
protected def isMatchingFile(f: JFile): Boolean =
f.isClass

private[dotty] def classes(inPackage: classpath.PackageName): Seq[classpath.ClassFileEntry] =
files(inPackage)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
package ammonite.compiler

import ammonite.util.{ImportData, Imports, Name => AmmName, Printer, Util}

import dotty.tools.dotc
import dotty.tools.dotc.core.StdNames.nme
import dotc.ast.Trees._
import dotc.ast.{tpd, untpd}
import dotc.core.Flags
import dotc.core.Contexts._
import dotc.core.Names.Name
import dotc.core.Phases.Phase
import dotc.core.Symbols.{NoSymbol, Symbol, newSymbol}
import dotc.core.Types.{TermRef, Type, TypeTraverser}

import scala.collection.mutable

class AmmonitePhase(
userCodeNestingLevel: => Int,
needsUsedEarlierDefinitions: => Boolean
) extends Phase:
import tpd._

def phaseName: String = "ammonite"

private var myImports = new mutable.ListBuffer[(Boolean, String, String, Seq[AmmName])]
private var usedEarlierDefinitions0 = new mutable.ListBuffer[String]

def importData: Seq[ImportData] =
val grouped = myImports
.toList
.distinct
.groupBy { case (a, b, c, d) => (b, c, d) }
.mapValues(_.map(_._1))

val open = for {
((fromName, toName, importString), items) <- grouped
if !CompilerUtil.ignoredNames(fromName)
} yield {
val importType = items match{
case Seq(true) => ImportData.Type
case Seq(false) => ImportData.Term
case Seq(_, _) => ImportData.TermType
}

ImportData(AmmName(fromName), AmmName(toName), importString, importType)
}

open.toVector.sortBy(x => Util.encodeScalaSourcePath(x.prefix))

def usedEarlierDefinitions: Seq[String] =
usedEarlierDefinitions0.toList.distinct

private def saneSym(name: Name, sym: Symbol)(using Context): Boolean =
!name.decode.toString.contains('$') &&
sym.exists &&
// !sym.is(Flags.Synthetic) &&
!scala.util.Try(sym.is(Flags.Private)).toOption.getOrElse(true) &&
!scala.util.Try(sym.is(Flags.Protected)).toOption.getOrElse(true) &&
// sym.is(Flags.Public) &&
!CompilerUtil.ignoredSyms(sym.toString) &&
!CompilerUtil.ignoredNames(name.decode.toString)

private def saneSym(sym: Symbol)(using Context): Boolean =
saneSym(sym.name, sym)

private def processTree(t: tpd.Tree)(using Context): Unit = {
val sym = t.symbol
val name = t match {
case t: tpd.ValDef => t.name
case _ => sym.name
}
if (saneSym(name, sym)) {
val name = sym.name.decode.toString
myImports.addOne((sym.isType, name, name, Nil))
}
}

private def processImport(i: tpd.Import)(using Context): Unit = {
val expr = i.expr
val selectors = i.selectors

// Most of that logic was adapted from AmmonitePlugin, the Scala 2 counterpart
// of this file.

val prefix =
val (_ :: nameListTail, symbolHead :: _) = {
def rec(expr: tpd.Tree): List[(Name, Symbol)] = {
expr match {
case s @ tpd.Select(lhs, _) => (s.symbol.name -> s.symbol) :: rec(lhs)
case i @ tpd.Ident(name) => List(name -> i.symbol)
case t @ tpd.This(pkg) => List(pkg.name -> t.symbol)
}
}
rec(expr).reverse.unzip
}

val headFullPath = symbolHead.fullName.decode.toString.split('.')
.map(n => if (n.endsWith("$")) n.stripSuffix("$") else n) // meh
// prefix package imports with `_root_` to try and stop random
// variables from interfering with them. If someone defines a value
// called `_root_`, this will still break, but that's their problem
val rootPrefix = if(symbolHead.denot.is(Flags.Package)) Seq("_root_") else Nil
val tailPath = nameListTail.map(_.decode.toString)

(rootPrefix ++ headFullPath ++ tailPath).map(AmmName(_))

def isMask(sel: untpd.ImportSelector) = sel.name != nme.WILDCARD && sel.rename == nme.WILDCARD

val renameMap =

/**
* A map of each name importable from `expr`, to a `Seq[Boolean]`
* containing a `true` if there's a type-symbol you can import, `false`
* if there's a non-type symbol and both if there are both type and
* non-type symbols that are importable for that name
*/
val importableIsTypes =
expr.tpe
.allMembers
.map(_.symbol)
.filter(saneSym(_))
.groupBy(_.name.decode.toString)
.mapValues(_.map(_.isType).toVector)

val renamings = for{
t @ untpd.ImportSelector(name, renameTree, _) <- selectors
if !isMask(t)
// getOrElse just in case...
isType <- importableIsTypes.getOrElse(name.name.decode.toString, Nil)
rename <- Option(renameTree).collect{ case Ident(r) => r }
} yield ((isType, rename.decode.toString), name.name.decode.toString)

renamings.toMap


def isUnimportableUnlessRenamed(sym: Symbol): Boolean =
sym eq NoSymbol

@scala.annotation.tailrec
def transformImport(selectors: List[untpd.ImportSelector], sym: Symbol): List[Symbol] =
selectors match {
case Nil => Nil
case sel :: Nil if sel.isWildcard =>
if (isUnimportableUnlessRenamed(sym)) Nil
else List(sym)
case (sel @ untpd.ImportSelector(from, to, _)) :: _
if from.name == (if (from.isTerm) sym.name.toTermName else sym.name.toTypeName) =>
if (isMask(sel)) Nil
else List(
newSymbol(sym.owner, sel.rename, sym.flags, sym.info, sym.privateWithin, sym.coord)
)
case _ :: rest => transformImport(rest, sym)
}

val symNames =
for {
sym <- expr.tpe.allMembers.map(_.symbol).flatMap(transformImport(selectors, _))
if saneSym(sym)
} yield (sym.isType, sym.name.decode.toString)

val syms = for {
// For some reason `info.allImportedSymbols` does not show imported
// type aliases when they are imported directly e.g.
//
// import scala.reflect.macros.Context
//
// As opposed to via import scala.reflect.macros._.
// Thus we need to combine allImportedSymbols with the renameMap
(isType, sym) <- (symNames.toList ++ renameMap.keys).distinct
} yield (isType, renameMap.getOrElse((isType, sym), sym), sym, prefix)

myImports ++= syms
}

private def updateUsedEarlierDefinitions(
wrapperSym: Symbol,
stats: List[tpd.Tree]
)(using Context): Unit = {
/*
* We list the variables from the first wrapper
* used from the user code.
*
* E.g. if, after wrapping, the code looks like
* ```
* class cmd2 {
*
* val cmd0 = ???
* val cmd1 = ???
*
* import cmd0.{
* n
* }
*
* class Helper {
* // user-typed code
* val n0 = n + 1
* }
* }
* ```
* this would process the tree of `val n0 = n + 1`, find `n` as a tree like
* `cmd2.this.cmd0.n`, and put `cmd0` in `uses`.
*/

val typeTraverser: TypeTraverser = new TypeTraverser {
def traverse(tpe: Type) = tpe match {
case tr: TermRef if tr.prefix.typeSymbol == wrapperSym =>
tr.designator match {
case n: Name => usedEarlierDefinitions0 += n.decode.toString
case s: Symbol => usedEarlierDefinitions0 += s.name.decode.toString
case _ => // can this happen?
}
case _ =>
traverseChildren(tpe)
}
}

val traverser: TreeTraverser = new TreeTraverser {
def traverse(tree: Tree)(using Context) = tree match {
case tpd.Select(node, name) if node.symbol == wrapperSym =>
usedEarlierDefinitions0 += name.decode.toString
case tt @ tpd.TypeTree() =>
typeTraverser.traverse(tt.tpe)
case _ =>
traverseChildren(tree)
}
}

for (tree <- stats)
traverser.traverse(tree)
}

private def unpkg(tree: tpd.Tree): List[tpd.Tree] =
tree match {
case PackageDef(_, elems) => elems.flatMap(unpkg)
case _ => List(tree)
}

def run(using Context): Unit =
val elems = unpkg(ctx.compilationUnit.tpdTree)
def mainStats(trees: List[tpd.Tree]): List[tpd.Tree] =
trees
.reverseIterator
.collectFirst {
case TypeDef(name, rhs0: Template) => rhs0.body
}
.getOrElse(Nil)

val rootStats = mainStats(elems)
val stats = (1 until userCodeNestingLevel)
.foldLeft(rootStats)((trees, _) => mainStats(trees))

if (needsUsedEarlierDefinitions) {
val wrapperSym = elems.last.symbol
updateUsedEarlierDefinitions(wrapperSym, stats)
}

stats.foreach {
case i: Import => processImport(i)
case t: tpd.DefDef => processTree(t)
case t: tpd.ValDef => processTree(t)
case t: tpd.TypeDef => processTree(t)
case _ =>
}
Loading