Skip to content

Commit

Permalink
Merge pull request #342 from lwronski/master
Browse files Browse the repository at this point in the history
Escape backticks in bash/zsh completion
  • Loading branch information
alexarchambault authored Nov 26, 2021
2 parents c8abc96 + 953183d commit 89bca5a
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 25 deletions.
4 changes: 3 additions & 1 deletion annotations/shared/src/main/scala/caseapp/Annotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@ object ValueDescription {
}

/** Help message for the annotated argument
* @messageMd
* not used by case-app itself, only there as a convenience for case-app users
*/
final case class HelpMessage(message: String) extends StaticAnnotation
final case class HelpMessage(message: String, messageMd: String = "") extends StaticAnnotation

/** Name for the annotated case class of arguments E.g. MyApp
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ object Bash {
}

private def escape(s: String): String =
s.replace("\"", "\\\"")
s.replace("\"", "\\\"").replace("`", "\\`").linesIterator.toStream.headOption.getOrElse("")
def print(items: Seq[CompletionItem]): String = {
val newLine = System.lineSeparator()
val b = new StringBuilder
Expand Down
30 changes: 10 additions & 20 deletions core/shared/src/main/scala/caseapp/core/complete/Zsh.scala
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
package caseapp.core.complete

import java.math.BigInteger
import java.nio.charset.StandardCharsets
import java.security.MessageDigest

import dataclass.data

import scala.collection.mutable
import scala.util.hashing.MurmurHash3

object Zsh {

Expand All @@ -24,32 +18,28 @@ object Zsh {
|}
|""".stripMargin

private def md5(content: Iterator[String]): String = {
val md = MessageDigest.getInstance("MD5")
for (s <- content) md.update(s.getBytes(StandardCharsets.UTF_8))
val digest = md.digest()
val res = new BigInteger(1, digest).toString(16)
if (res.length < 32)
("0" * (32 - res.length)) + res
else
res
private def hash(content: Iterator[String]): String = {
val hash = MurmurHash3.arrayHash(content.toArray)
if (hash < 0) (hash * -1).toString
else hash.toString
}

private def escape(s: String): String =
s.replace("'", "\\'").replace("`", "\\`").linesIterator.toStream.headOption.getOrElse("")
private def defs(item: CompletionItem): Seq[String] = {
val (options, arguments) = item.values.partition(_.startsWith("-"))
val optionsOutput =
if (options.isEmpty) Nil
else {
val escapedOptions = options
val desc = item.description.map(":" + _.replace("'", "\\'")).getOrElse("")
val desc = item.description.map(desc => ":" + escape(desc)).getOrElse("")
options.map { opt =>
"\"" + opt + desc + "\""
}
}
val argumentsOutput =
if (arguments.isEmpty) Nil
else {
val desc = item.description.map(":" + _.replace("'", "\\'")).getOrElse("")
val desc = item.description.map(desc => ":" + escape(desc)).getOrElse("")
arguments.map("'" + _.replace(":", "\\:") + desc + "'")
}
optionsOutput ++ argumentsOutput
Expand All @@ -58,7 +48,7 @@ object Zsh {
private def render(commands: Seq[String]): String =
if (commands.isEmpty) "_files" + System.lineSeparator()
else {
val id = md5(commands.iterator)
val id = hash(commands.iterator)
s"""local -a args$id
|args$id=(
|${commands.mkString(System.lineSeparator())}
Expand Down
2 changes: 1 addition & 1 deletion project/Mima.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import scala.sys.process._
object Mima {

def binaryCompatibilityVersions: Set[String] =
Seq("git", "tag", "--merged", "HEAD^", "--contains", "706a1d90cca205b69e4cff583abb9411526e7d58")
Seq("git", "tag", "--merged", "HEAD^", "--contains", "c8abc969f219357022ea8cf816b4e7653833c620")
.!!
.linesIterator
.map(_.trim)
Expand Down
13 changes: 12 additions & 1 deletion tests/shared/src/test/scala/caseapp/CompletionDefinitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -59,18 +59,29 @@ object CompletionDefinitions {
@Name("g") @HelpMessage("A pattern") glob: String = "",
@Name("d") count: Int = 0
)
case class BackTickOptions(
@HelpMessage(
"""A pattern with backtick `--`
|with multiline""".stripMargin
) backtick: String = "",
@Name("d") count: Int = 0
)
object First extends Command[FirstOptions] {
def run(options: FirstOptions, args: RemainingArgs): Unit = ???
}
object Second extends Command[SecondOptions] {
def run(options: SecondOptions, args: RemainingArgs): Unit = ???
}
object BackTick extends Command[BackTickOptions] {
def run(options: BackTickOptions, args: RemainingArgs): Unit = ???
}

object Prog extends CommandsEntryPoint {
def progName = "prog"
def commands = Seq(
First,
Second
Second,
BackTick
)
}
}
Expand Down
30 changes: 29 additions & 1 deletion tests/shared/src/test/scala/caseapp/CompletionTests.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package caseapp

import caseapp.core.complete.CompletionItem
import caseapp.core.complete.{Bash, CompletionItem, Zsh}
import utest._

object CompletionTests extends TestSuite {
Expand Down Expand Up @@ -110,6 +110,7 @@ object CompletionTests extends TestSuite {
test {
val res = Prog.complete(Seq(""), 0)
val expected = List(
CompletionItem("back-tick", None, Nil),
CompletionItem("first", None, Nil),
CompletionItem("second", None, Nil)
)
Expand Down Expand Up @@ -140,6 +141,33 @@ object CompletionTests extends TestSuite {
)
assert(res == expected)
}

test("bash") {
val res = Prog.complete(Seq("back-tick", "-"), 1)
val expected = List(
CompletionItem("--backtick", Some("A pattern with backtick `--`\nwith multiline"), Nil),
CompletionItem("--count", None, List("-d"))
)
assert(res == expected)

val compRely = Bash.print(res)
val expectedCompRely = """"--backtick -- A pattern with backtick \`--\`"""".stripMargin

assert(compRely.contains(expectedCompRely))
}
test("zsh") {
val res = Prog.complete(Seq("back-tick", "-"), 1)
val expected = List(
CompletionItem("--backtick", Some("A pattern with backtick `--`\nwith multiline"), Nil),
CompletionItem("--count", None, List("-d"))
)
assert(res == expected)

val compRely = Zsh.print(res)
val expectedCompRely = """"--backtick:A pattern with backtick \`--\`"""".stripMargin

assert(compRely.contains(expectedCompRely))
}
}

test("commands with default") {
Expand Down

0 comments on commit 89bca5a

Please sign in to comment.