diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index da8f3a24ff27e..11be1d85fbead 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -100,7 +100,7 @@ private[libsvm] class LibSVMFileFormat
         "though the input. If you know the number in advance, please specify it via " +
         "'numFeatures' option to avoid the extra scan.")
 
-      val paths = files.map(_.getPath.toUri.toString)
+      val paths = files.map(_.getPath.toString)
       val parsed = MLUtils.parseLibSVMFile(sparkSession, paths)
       MLUtils.computeNumFeatures(parsed)
     }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index 9198334ba02a1..24113003c9675 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -110,7 +110,8 @@ object MLUtils extends Logging {
       DataSource.apply(
         sparkSession,
         paths = paths,
-        className = classOf[TextFileFormat].getName
+        className = classOf[TextFileFormat].getName,
+        options = Map(DataSource.GLOB_PATHS_KEY -> "false")
       ).resolveRelation(checkFilesExist = false))
       .select("value")
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
index 263ad26657545..0999892364e2c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
@@ -191,4 +191,24 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
       spark.sql("DROP TABLE IF EXISTS libsvmTable")
     }
   }
+
+  test("SPARK-32815: Test LibSVM data source on file paths with glob metacharacters") {
+    withTempDir { dir =>
+      val basePath = dir.getCanonicalPath
+      // test libsvm writer / reader without specifying schema
+      val svmFileName = "[abc]"
+      val escapedSvmFileName = "\\[abc\\]"
+      val rawData = new java.util.ArrayList[Row]()
+      rawData.add(Row(1.0, Vectors.sparse(2, Seq((0, 2.0), (1, 3.0)))))
+      val struct = new StructType()
+        .add("labelFoo", DoubleType, false)
+        .add("featuresBar", VectorType, false)
+      val df = spark.createDataFrame(rawData, struct)
+      df.write.format("libsvm").save(s"$basePath/$svmFileName")
+      val df2 = spark.read.format("libsvm").load(s"$basePath/$escapedSvmFileName")
+      val row1 = df2.first()
+      val v = row1.getAs[SparseVector](1)
+      assert(v == Vectors.sparse(2, Seq((0, 2.0), (1, 3.0))))
+    }
+  }
 }