diff --git a/metals/src/main/scala/scala/meta/internal/metals/CodeActionProvider.scala b/metals/src/main/scala/scala/meta/internal/metals/CodeActionProvider.scala index 4ba9adcd539..389b1875a56 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/CodeActionProvider.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/CodeActionProvider.scala @@ -44,7 +44,8 @@ final class CodeActionProvider( new PatternMatchRefactor(trees), new RewriteBracesParensCodeAction(trees), new ExtractValueCodeAction(trees, buffers), - new CreateCompanionObjectCodeAction(trees, buffers) + new CreateCompanionObjectCodeAction(trees, buffers), + new FlatMapToForComprehensionCodeAction(trees, buffers) ) def codeActions( diff --git a/metals/src/main/scala/scala/meta/internal/metals/codeactions/FlatMapToForComprehensionCodeAction.scala b/metals/src/main/scala/scala/meta/internal/metals/codeactions/FlatMapToForComprehensionCodeAction.scala new file mode 100644 index 00000000000..b1eb145d9ed --- /dev/null +++ b/metals/src/main/scala/scala/meta/internal/metals/codeactions/FlatMapToForComprehensionCodeAction.scala @@ -0,0 +1,502 @@ +package scala.meta.internal.metals.codeactions + +import scala.concurrent.ExecutionContext +import scala.concurrent.Future + +import scala.meta.Enumerator +import scala.meta.Name +import scala.meta.Pat +import scala.meta.Term +import scala.meta.Tree +import scala.meta.inputs.Position +import scala.meta.internal.metals.Buffers +import scala.meta.internal.metals.CodeAction +import scala.meta.internal.metals.MetalsEnrichments._ +import scala.meta.internal.parsing.Trees +import scala.meta.io.AbsolutePath +import scala.meta.pc.CancelToken + +import org.eclipse.lsp4j.CodeActionParams +import org.eclipse.{lsp4j => l} + +class FlatMapToForComprehensionCodeAction( + trees: Trees, + buffers: Buffers +) extends CodeAction { + override def kind: String = l.CodeActionKind.RefactorRewrite + + override def contribute(params: CodeActionParams, token: CancelToken)(implicit + ec: ExecutionContext + ): Future[Seq[l.CodeAction]] = Future { + val uri = params.getTextDocument().getUri() + + val path = uri.toAbsolutePath + val range = params.getRange() + val maybeTree = + if (range.getStart == range.getEnd) + trees + .findLastEnclosingAt[Tree]( + path, + range.getStart(), + isTreeInteresting + ) + else + None + + def parse(code: String): Option[Tree] = trees.parse(path)(code).toOption + + def parseTerm(code: String): Option[Term] = + trees.parseTerm(path)(code).toOption + + val maybeChainedCodeAction = for { + document <- buffers.get(path) + applyTree <- maybeTree + indentation = getIndentForPos(applyTree.pos, document) + } yield { + applyTree match { + case termApply: Term.Apply => + codeActionWithApply( + parseTerm, + parse, + path, + termApply, + document, + indentation + ) + case termSelect: Term.Select => + codeActionWithSelect( + parseTerm, + parse, + path, + termSelect, + document, + indentation + ) + + case termName: Term.Name => + codeActionWithName( + parseTerm, + parse, + path, + termName, + document, + indentation + ) + + case _ => None + } + } + + maybeChainedCodeAction.flatten.toSeq + + } + + private def constructCodeAction( + forElementsList: List[Enumerator], + yieldTerm: Term, + indentation: String, + path: AbsolutePath, + startPos: l.Position, + endPos: l.Position + ): l.CodeAction = { + + val forYieldIndentedString = Term.ForYield + .apply(forElementsList, yieldTerm) + .syntax + .split(Array('\n')) + .map(line => s"$indentation $line") + .mkString("\n") + + val forYieldString = + s"""|{ + |$forYieldIndentedString + |$indentation}""".stripMargin + + val codeAction = new l.CodeAction() + val range = + new l.Range(startPos, endPos) + codeAction.setTitle( + FlatMapToForComprehensionCodeAction.flatMapToForComprehension + ) + codeAction.setKind(this.kind) + val forComprehensionTextEdit = new l.TextEdit(range, forYieldString) + codeAction.setEdit( + new l.WorkspaceEdit( + Map( + path.toURI.toString -> List(forComprehensionTextEdit).asJava + ).asJava + ) + ) + codeAction + } + + private def codeActionWithApply( + parseTerm: String => Option[Term], + parse: String => Option[Tree], + path: AbsolutePath, + termApply: Term.Apply, + document: String, + indentation: String + ): Option[l.CodeAction] = { + val (forElements, maybeYieldTerm, generatedByMetalsVals) = { + extractChainedForYield( + parseTerm, + parse, + None, + None, + List.empty, + termApply, + document, + Set.empty + ) + } + + if (forElements.nonEmpty) { + maybeYieldTerm.map { yieldTerm => + constructCodeAction( + forElements, + yieldTerm, + indentation, + path, + termApply.pos.toLSP.getStart, + termApply.pos.toLSP.getEnd + ) + } + } else None + + } + + private def codeActionWithSelect( + parseTerm: String => Option[Term], + parse: String => Option[Tree], + path: AbsolutePath, + termSelect: Term.Select, + document: String, + indentation: String + ): Option[l.CodeAction] = { + termSelect.parent.collect { case termApply: Term.Apply => + codeActionWithApply( + parseTerm, + parse, + path, + termApply, + document, + indentation + ) + }.flatten + } + + private def codeActionWithName( + parseTerm: String => Option[Term], + parse: String => Option[Tree], + path: AbsolutePath, + termName: Term.Name, + document: String, + indentation: String + ): Option[l.CodeAction] = { + termName.parent.collect { case termSelect: Term.Select => + codeActionWithSelect( + parseTerm, + parse, + path, + termSelect, + document, + indentation + ) + }.flatten + + } + + private def replacePlaceHolder(tree: Term, newName: Term.Name): Term = + tree match { + case Term.Apply(fun, args) => + Term.Apply( + replacePlaceHolder(fun, newName), + args.map(replacePlaceHolder(_, newName)) + ) + case Term.Select(qual, name) => + Term.Select(replacePlaceHolder(qual, newName), name) + case Term.Placeholder() => newName + case other => other + } + + private def processValueNameAndNextQual( + tree: Tree, + generatedByMetalsValues: Set[String] + ): (Option[(String, Term)], Set[String]) = { + tree match { + case Term.Function(List(param), term) if param.name.value.isEmpty => + val (newName, generatedValues) = + createNewName(tree, generatedByMetalsValues) + (Some((newName, term)), generatedValues) + case Term.Function(List(param), term) => + (Some((param.name.value, term)), generatedByMetalsValues) + + case Term.AnonymousFunction(term) => + val (newName, generatedValues) = + createNewName(tree, generatedByMetalsValues) + ( + Some((newName, replacePlaceHolder(term, Term.Name(newName)))), + generatedValues + ) + case Term.Block(List(function)) => + processValueNameAndNextQual(function, generatedByMetalsValues) + case term: Term => + val (newName, generatedValues) = + createNewName(tree, generatedByMetalsValues) + ( + Some((newName, Term.Apply(term, List(Term.Name(newName))))), + generatedValues + ) + } + } + + private def extractChainedForYield( + parseTerm: String => Option[Term], + parse: String => Option[Tree], + perhapseLastName: Option[String], + currentYieldTerm: Option[Term], + existingForElements: List[Enumerator], + termApply: Term.Apply, + document: String, + generatedByMetalsValues: Set[String] + ): (List[Enumerator], Option[Term], Set[String]) = { + val (perhapseValueNameAndNextQual, newMetalsNames) = + processValueNameAndNextQual( + termApply.args.head, + generatedByMetalsValues + ) + + termApply.fun match { + case termSelect: Term.Select + if termSelect.name.value == "flatMap" || termSelect.name.value == "map" => + val qual = termSelect.qual + val assignOrMap = + if (termSelect.name.value == "flatMap") AssignOrMap.map + else AssignOrMap.assign + + val result = for { + valueName <- perhapseValueNameAndNextQual.map(_._1) + nextQual <- perhapseValueNameAndNextQual.map(_._2) + } yield { + val (elems, maybeYieldTerm, updatedGenByMetalsVals) = + perhapseLastName match { + case Some(lastName) => + ( + List(assignOrMap match { + case AssignOrMap.assign => + Enumerator.Val.apply( + Pat.Var.apply(Term.Name.apply(lastName)), + nextQual + ) + case AssignOrMap.map => + Enumerator.Generator.apply( + Pat.Var.apply(Term.Name.apply(lastName)), + nextQual + ) + }) ++ existingForElements, + currentYieldTerm, + newMetalsNames + ) + case None => + assignOrMap match { + case AssignOrMap.assign => + (existingForElements, Some(nextQual), newMetalsNames) + + case AssignOrMap.map => + val (lastGeneratedName, newGeneratedByMetalsVals) = + createNewName(termApply, newMetalsNames) + val newEnumerations = List( + Enumerator.Generator( + Pat.Var.apply(Term.Name.apply(lastGeneratedName)), + nextQual + ) + ) + val newYield = Term.Name.apply(lastGeneratedName) + (newEnumerations, Some(newYield), newGeneratedByMetalsVals) + } + } + qual match { + case qualTermApply: Term.Apply => + extractChainedForYield( + parseTerm, + parse, + Some(valueName), + maybeYieldTerm, + elems, + qualTermApply, + document, + updatedGenByMetalsVals + ) + case otherQual => + ( + Enumerator.Generator( + Pat.Var.apply(Term.Name.apply(valueName)), + otherQual + ) + +: elems, + maybeYieldTerm, + updatedGenByMetalsVals + ) + + } + } + result.getOrElse(List.empty, None, newMetalsNames) + + case termSelect: Term.Select + if termSelect.name.value == "filter" || termSelect.name.value == "filterNot" || + termSelect.name.value == "withFilter" => + val qual = termSelect.qual + val filterOrNot = + if ( + termSelect.name.value == "filter" || termSelect.name.value == "withFilter" + ) FilterOrNot.filter + else FilterOrNot.filterNot + val result = for { + valueName <- perhapseValueNameAndNextQual.map(_._1) + nextCondition <- perhapseValueNameAndNextQual.map(_._2) + } yield { + val (elems, maybeYieldTerm): (List[Enumerator], Option[Term]) = + perhapseLastName match { + case Some(lastName) => + ( + Enumerator.Val.apply( + Pat.Var.apply(Term.Name.apply(lastName)), + Term.Name.apply(valueName) + ) +: existingForElements, + currentYieldTerm + ) + case None => + (existingForElements, Some(Term.Name.apply(valueName))) + } + + qual match { + case qualTermApply: Term.Apply => + extractChainedForYield( + parseTerm, + parse, + Some(valueName), + maybeYieldTerm, + Enumerator.Guard.apply( + filterOrNot match { + case FilterOrNot.filter => nextCondition + case FilterOrNot.filterNot => + Term.ApplyUnary.apply(Term.Name.apply("!"), nextCondition) + } + ) +: elems, + qualTermApply, + document, + newMetalsNames + ) + case otherQual => + ( + Enumerator.Generator( + Pat.Var.apply(Term.Name.apply(valueName)), + otherQual + ) + +: Enumerator.Guard.apply( + filterOrNot match { + case FilterOrNot.filter => nextCondition + case FilterOrNot.filterNot => + Term.ApplyUnary + .apply(Term.Name.apply("!"), nextCondition) + } + ) +: elems, + maybeYieldTerm, + newMetalsNames + ) + } + } + result.getOrElse(List.empty, currentYieldTerm, generatedByMetalsValues) + case _ => + perhapseLastName match { + case Some(lastName) => + ( + Enumerator.Generator( + Pat.Var.apply(Term.Name.apply(lastName)), + termApply + ) + +: existingForElements, + currentYieldTerm, + generatedByMetalsValues + ) + case None => + (existingForElements, Some(termApply), generatedByMetalsValues) + } + } + } + + private def createNewName( + tree: Tree, + generatedByMetalsValues: Set[String] + ): (String, Set[String]) = { + + def findTopMostParent: Tree = { + var initialParent = tree + while (initialParent.parent.isDefined) + initialParent = initialParent.parent.get + + initialParent + } + + // We don't want to use any name that is already being used in the scope + def loop(t: Tree): List[String] = { + t.children.flatMap { + case n: Name => List(n.toString()) + case child => loop(child) + } + } + + val newValuePrefix = "generatedByMetals" + val names = loop(findTopMostParent).toSet ++ generatedByMetalsValues + + if (!names(newValuePrefix)) + (newValuePrefix, generatedByMetalsValues + newValuePrefix) + else { + var i = 0 + while (names(s"$newValuePrefix$i")) + i += 1 + val result = s"$newValuePrefix$i" + (result, generatedByMetalsValues + result) + } + } + + private def getIndentForPos( + treePos: Position, + document: String + ): String = + document + .substring(treePos.start - treePos.startColumn, treePos.start) + .takeWhile(_.isWhitespace) + + private def isTreeInteresting: Tree => Boolean = { + case _: Term.Apply => true + case termSelect: Term.Select + if termSelect.name.value == "map" || termSelect.name.value == "flatMap" || + termSelect.name.value == "filter" || termSelect.name.value == "filterNot" || + termSelect.name.value == "withFilter" => + true + case termName: Term.Name + if termName.value == "flatMap" || termName.value == "map" || + termName.value == "filter" || termName.value == "filterNot" || + termName.value == "withFilter" => + true + case _ => false + } +} + +object FlatMapToForComprehensionCodeAction { + val flatMapToForComprehension = "Turn into for comprehension" +} + +object AssignOrMap extends Enumeration { + type AssignOrMap = Value + val assign: Value = Value("=") + val map: Value = Value("<-") +} + +object FilterOrNot extends Enumeration { + type FilterOrNot = Value + val filter: Value = Value("") + val filterNot: Value = Value("!") +} diff --git a/metals/src/main/scala/scala/meta/internal/parsing/Trees.scala b/metals/src/main/scala/scala/meta/internal/parsing/Trees.scala index d658a1e94b8..793c2d3930a 100644 --- a/metals/src/main/scala/scala/meta/internal/parsing/Trees.scala +++ b/metals/src/main/scala/scala/meta/internal/parsing/Trees.scala @@ -18,7 +18,7 @@ import org.eclipse.{lsp4j => l} * Manages parsing of Scala source files into Scalameta syntax trees. * * - provides the latest good Scalameta tree for a given source file - * similar as `Buffers` provides the current text content. + * similar as `Buffers` provides the current text content. * - produces diagnostics for syntax errors. */ final class Trees( @@ -51,8 +51,8 @@ final class Trees( /** * Find last tree matching T that encloses the position. * - * @param source source to load the tree for - * @param lspPos cursor position + * @param source source to load the tree for + * @param lspPos cursor position * @param predicate predicate which T must fulfill * @return found tree node of type T or None */ @@ -72,6 +72,7 @@ final class Trees( enclosedChildren(other.children, pos).flatMap(loop(_, pos)) } } + get(source).flatMap { tree => val pos = lspPos.toMeta(tree.pos.input) loop(tree, pos) @@ -81,6 +82,7 @@ final class Trees( /** * Parse file at the given path and return a list of errors if there are any. + * * @param path file to parse * @return list of errors if the file failed to parse */ @@ -128,6 +130,32 @@ final class Trees( } } + def parse( + path: AbsolutePath + )(code: String): Parsed[Tree] = { + val dialect = scalaVersionSelector.getDialect(path) + val input = Input.VirtualFile(path.toURI.toString(), code) + if (path.isAmmoniteScript) { + val ammoniteInput = Input.Ammonite(input) + dialect(ammoniteInput).parse(Parse.parseAmmonite) + } else { + dialect(input).parse[Source] + } + } + + def parseTerm( + path: AbsolutePath + )(code: String): Parsed[Term] = { + val dialect = scalaVersionSelector.getDialect(path) + val input = Input.VirtualFile(path.toURI.toString(), code) + if (path.isAmmoniteScript) { + val ammoniteInput = Input.Ammonite(input) + dialect(ammoniteInput).parse[Term](Parse.parseTerm) + } else { + dialect(input).parse[Term] + } + } + } object Trees { diff --git a/mtags/src/main/scala-2.11/scala/meta/internal/jdk/CollectionConverters.scala b/mtags/src/main/scala-2.11/scala/meta/internal/jdk/CollectionConverters.scala index 85f83474e78..06ec087b6ff 100644 --- a/mtags/src/main/scala-2.11/scala/meta/internal/jdk/CollectionConverters.scala +++ b/mtags/src/main/scala-2.11/scala/meta/internal/jdk/CollectionConverters.scala @@ -1,5 +1,6 @@ package scala.meta.internal.jdk -import scala.collection.convert.{DecorateAsJava, DecorateAsScala} +import scala.collection.convert.DecorateAsJava +import scala.collection.convert.DecorateAsScala object CollectionConverters extends DecorateAsJava with DecorateAsScala diff --git a/tests/unit/src/test/scala/tests/codeactions/FlatMapToForComprehensionSuite.scala b/tests/unit/src/test/scala/tests/codeactions/FlatMapToForComprehensionSuite.scala new file mode 100644 index 00000000000..737eb3161b2 --- /dev/null +++ b/tests/unit/src/test/scala/tests/codeactions/FlatMapToForComprehensionSuite.scala @@ -0,0 +1,128 @@ +package tests.codeactions + +import scala.meta.internal.metals.codeactions.FlatMapToForComprehensionCodeAction +import scala.meta.internal.metals.codeactions.RewriteBracesParensCodeAction + +class FlatMapToForComprehensionSuite + extends BaseCodeActionLspSuite("forComprehension") { + + check( + "mixture", + """|object A { + | def double(x : Int, y: Int = 1) = y * x + | def check(x: Int) = true + | val list = List(1, 2, 3) + | + | val res3 = list + | .flatMap{ + | a => + | { + | val m = 6 + | Some(a + 1).map(b => b + 3 + 4) + | } + | + | }.filter (check) + | .filterNot(_ => true) + | .map(_ => 7) + | .map(c => c - 1) + | .filter(d => d > 1) + | .map( double(_, 4).toFloat.toDouble) + | .m<<>>ap( _.toInt.compare(3)) + | + |} + |""".stripMargin, + s"""|${RewriteBracesParensCodeAction.toBraces} + |${FlatMapToForComprehensionCodeAction.flatMapToForComprehension} + |""".stripMargin, + """|object A { + | def double(x : Int, y: Int = 1) = y * x + | def check(x: Int) = true + | val list = List(1, 2, 3) + | + | val res3 = { + | for { + | a <- list + | generatedByMetals3 <- { + | val m = 6 + | Some(a + 1).map(b => b + 3 + 4) + | } + | if (check(generatedByMetals3)) + | generatedByMetals2 = generatedByMetals3 + | if !(true) + | generatedByMetals1 = generatedByMetals2 + | c = 7 + | d = c - 1 + | if (d > 1) + | generatedByMetals0 = d + | generatedByMetals = double(generatedByMetals0, 4).toFloat.toDouble + | } yield { + | generatedByMetals.toInt.compare(3) + | } + | } + | + |} + |""".stripMargin, + selectedActionIndex = 1 + ) + + check( + "flatten-for-flatMap", + """|object A { + | def double(x : Int, y: Int = 1) = y * x + | def check(x: Int) = true + | val list = List(1, 2, 3) + | + | val res3 = list + | .flatMap{ + | a => { + | val m = 6 + | Some(a + 1).map(b => b + 3 + 4) + | } + | + | }.filter (check) + | .filterNot(_ => true) + | .map(_ => 7) + | .map(c => c - 1) + | .filter(d => d > 1) + | .map( double(_, 4).toFloat.toDouble) + | .map( _.toInt.compare(3)) + | .fl<<>>atMap( m => Some(m * 3)) + | + |} + |""".stripMargin, + s"""|${RewriteBracesParensCodeAction.toBraces} + |${FlatMapToForComprehensionCodeAction.flatMapToForComprehension} + |""".stripMargin, + """|object A { + | def double(x : Int, y: Int = 1) = y * x + | def check(x: Int) = true + | val list = List(1, 2, 3) + | + | val res3 = { + | for { + | a <- list + | generatedByMetals3 <- { + | val m = 6 + | Some(a + 1).map(b => b + 3 + 4) + | } + | if (check(generatedByMetals3)) + | generatedByMetals2 = generatedByMetals3 + | if !(true) + | generatedByMetals1 = generatedByMetals2 + | c = 7 + | d = c - 1 + | if (d > 1) + | generatedByMetals0 = d + | generatedByMetals = double(generatedByMetals0, 4).toFloat.toDouble + | m = generatedByMetals.toInt.compare(3) + | generatedByMetals4 <- Some(m * 3) + | } yield { + | generatedByMetals4 + | } + | } + | + |} + |""".stripMargin, + selectedActionIndex = 1 + ) +}