diff --git a/core/src/main/java/com/facebook/ktfmt/Main.kt b/core/src/main/java/com/facebook/ktfmt/Main.kt index 62681246..f5845835 100644 --- a/core/src/main/java/com/facebook/ktfmt/Main.kt +++ b/core/src/main/java/com/facebook/ktfmt/Main.kt @@ -121,7 +121,10 @@ fun expandArgsToFileNames(args: List): List { if (arg == "-") { error("Error: '-', which causes ktfmt to read from stdin, should not be mixed with file name") } - result.addAll(File(arg).walkTopDown().filter { it.isFile && it.extension == "kt" }) + result.addAll( + File(arg).walkTopDown().filter { + it.isFile && (it.extension == "kt" || it.extension == "kts") + }) } return result } diff --git a/core/src/test/java/com/facebook/ktfmt/MainKtTest.kt b/core/src/test/java/com/facebook/ktfmt/MainKtTest.kt index 72f6d643..c9e3c72e 100644 --- a/core/src/test/java/com/facebook/ktfmt/MainKtTest.kt +++ b/core/src/test/java/com/facebook/ktfmt/MainKtTest.kt @@ -216,4 +216,25 @@ class MainKtTest { assertThat(output.toString("UTF-8")).isEqualTo(code) } + + @Test + fun `expandArgsToFileNames - resolves 'kt' and 'kts' filenames only (recursively)`() { + val f1 = root.resolve("1.kt") + val f2 = root.resolve("2.kt") + val f3 = root.resolve("3") + val f4 = root.resolve("4.dummyext") + val f5 = root.resolve("5.kts") + + val dir = root.resolve("foo") + dir.mkdirs() + val f6 = root.resolve("foo/1.kt") + val f7 = root.resolve("foo/2.kts") + val f8 = root.resolve("foo/3.dummyext") + val files = listOf(f1, f2, f3, f4, f5, f6, f7, f8) + for (f in files) { + f.createNewFile() + } + assertThat(expandArgsToFileNames(files.map { it.toString() })) + .containsExactly(f1, f2, f5, f6, f7) + } }