[SPARK-18408][ML] API Improvements for LSH #15874
Changes from 18 commits
@@ -33,28 +33,28 @@ import org.apache.spark.sql.types._
  */
 private[ml] trait LSHParams extends HasInputCol with HasOutputCol {
   /**
-   * Param for the dimension of LSH OR-amplification.
+   * Param for the number of hash tables used in LSH OR-amplification.
    *
-   * In this implementation, we use LSH OR-amplification to reduce the false negative rate. The
-   * higher the dimension is, the lower the false negative rate.
+   * LSH OR-amplification can be used to reduce the false negative rate. Higher values for this
+   * param lead to a reduced false negative rate, at the expense of added computational complexity.
    * @group param
    */
-  final val outputDim: IntParam = new IntParam(this, "outputDim", "output dimension, where" +
-    " increasing dimensionality lowers the false negative rate, and decreasing dimensionality" +
-    " improves the running performance", ParamValidators.gt(0))
+  final val numHashTables: IntParam = new IntParam(this, "numHashTables", "number of hash " +
+    "tables, where increasing number of hash tables lowers the false negative rate, and " +
+    "decreasing it improves the running performance", ParamValidators.gt(0))
 
   /** @group getParam */
-  final def getOutputDim: Int = $(outputDim)
+  final def getNumHashTables: Int = $(numHashTables)
 
-  setDefault(outputDim -> 1)
+  setDefault(numHashTables -> 1)
 
   /**
    * Transform the Schema for LSH
    * @param schema The schema of the input dataset without [[outputCol]]
   * @return A derived schema with [[outputCol]] added
    */
   protected[this] final def validateAndTransformSchema(schema: StructType): StructType = {
-    SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
+    SchemaUtils.appendColumn(schema, $(outputCol), DataTypes.createArrayType(new VectorUDT))
   }
 }
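For intuition (not part of the diff): OR-amplification treats two records as candidate neighbors if they collide in at least one of the `numHashTables` hash tables, which is why a higher value lowers the false negative rate at added compute cost. With the new array-of-vectors output schema, that candidate test is a one-liner; a minimal sketch with invented hash values:

```scala
import org.apache.spark.ml.linalg.{Vector, Vectors}

// Hash signatures of two records, one Vector per hash table
// (illustrative values; real ones come from the model's hash function).
val hashesA: Seq[Vector] = Seq(Vectors.dense(1.0), Vectors.dense(4.0), Vectors.dense(2.0))
val hashesB: Seq[Vector] = Seq(Vectors.dense(1.0), Vectors.dense(5.0), Vectors.dense(3.0))

// OR-amplification: a candidate pair if ANY table puts both in the same bucket.
val isCandidatePair = hashesA.zip(hashesB).exists { case (a, b) => a == b }
// isCandidatePair == true here, because table 0 collides.
```

This is exactly the `sameBucket` predicate that appears further down in this diff.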
@@ -66,10 +66,10 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
   self: T =>
 
   /**
-   * The hash function of LSH, mapping a predefined KeyType to a Vector
+   * The hash function of LSH, mapping an input feature vector to multiple hash vectors.
    * @return The mapping of LSH function.
    */
-  protected[ml] val hashFunction: Vector => Vector
+  protected[ml] val hashFunction: Vector => Array[Vector]
 
   /**
    * Calculate the distance between two different keys using the distance metric corresponding
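To make the new `Vector => Array[Vector]` signature concrete, here is a hypothetical random-projection hasher (a sketch, not this PR's implementation; `randUnitVectors`, `bucketLength`, and all values are invented for illustration): each hash table projects the input onto its own random unit vector and floors the projection into a bucket index, producing one single-element hash Vector per table.

```scala
import scala.util.Random
import org.apache.spark.ml.linalg.{Vector, Vectors}

val rng = new Random(12345)
val numHashTables = 3
val dim = 10
val bucketLength = 2.5

// One random unit vector per hash table (hypothetical model state).
val randUnitVectors: Array[Vector] = Array.fill(numHashTables) {
  val v = Array.fill(dim)(rng.nextGaussian())
  val norm = math.sqrt(v.map(x => x * x).sum)
  Vectors.dense(v.map(_ / norm))
}

// One hash Vector per hash table, matching the new signature.
val hashFunction: Vector => Array[Vector] = (elems: Vector) =>
  randUnitVectors.map { u =>
    val proj = elems.toArray.zip(u.toArray).map { case (a, b) => a * b }.sum
    Vectors.dense(math.floor(proj / bucketLength))
  }
```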
@@ -87,41 +87,24 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
    * @param y Another hash vector
    * @return The distance between hash vectors x and y
    */
-  protected[ml] def hashDistance(x: Vector, y: Vector): Double
+  protected[ml] def hashDistance(x: Seq[Vector], y: Seq[Vector]): Double
 
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
-    val transformUDF = udf(hashFunction, new VectorUDT)
+    val transformUDF = udf(hashFunction, DataTypes.createArrayType(new VectorUDT))
     dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol))))
   }
 
   override def transformSchema(schema: StructType): StructType = {
     validateAndTransformSchema(schema)
   }
 
-  /**
-   * Given a large dataset and an item, approximately find at most k items which have the closest
-   * distance to the item. If the [[outputCol]] is missing, the method will transform the data; if
-   * the [[outputCol]] exists, it will use the [[outputCol]]. This allows caching of the
-   * transformed data when necessary.
-   *
-   * This method implements two ways of fetching k nearest neighbors:
-   *  - Single Probing: Fast, return at most k elements (Probing only one buckets)
-   *  - Multiple Probing: Slow, return exact k elements (Probing multiple buckets close to the key)
-   *
-   * @param dataset the dataset to search for nearest neighbors of the key
-   * @param key Feature vector representing the item to search for
-   * @param numNearestNeighbors The maximum number of nearest neighbors
-   * @param singleProbing True for using Single Probing; false for multiple probing
-   * @param distCol Output column for storing the distance between each result row and the key
-   * @return A dataset containing at most k items closest to the key. A distCol is added to show
-   *         the distance between each row and the key.
-   */
-  def approxNearestNeighbors(
+  // TODO: Fix the MultiProbe NN Search in SPARK-18454
+  private[feature] def approxNearestNeighbors(
       dataset: Dataset[_],
       key: Vector,
       numNearestNeighbors: Int,
-      singleProbing: Boolean,
+      singleProbe: Boolean,
       distCol: String): Dataset[_] = {
     require(numNearestNeighbors > 0, "The number of nearest neighbors cannot be less than 1")
     // Get Hash Value of the key
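The concrete aggregation over hash tables is left to subclasses; one plausible implementation of the new `Seq[Vector]`-based signature (a sketch, assuming minimum-over-tables is the desired OR-amplification semantics; actual models may aggregate differently):

```scala
import org.apache.spark.ml.linalg.Vector

// Distance between two hash signatures: the smallest per-table squared
// Euclidean distance, so colliding in any single table yields distance 0.
def hashDistance(x: Seq[Vector], y: Seq[Vector]): Double =
  x.zip(y).map { case (vx, vy) =>
    vx.toArray.zip(vy.toArray).map { case (a, b) => (a - b) * (a - b) }.sum
  }.min
```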
@@ -132,14 +115,24 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
       dataset.toDF()
     }
 
-    // In the origin dataset, find the hash value that is closest to the key
-    val hashDistUDF = udf((x: Vector) => hashDistance(x, keyHash), DataTypes.DoubleType)
-    val hashDistCol = hashDistUDF(col($(outputCol)))
-    val modelSubset = if (singleProbing) {
-      modelDataset.filter(hashDistCol === 0.0)
+    val modelSubset = if (singleProbe) {
+      def sameBucket(x: Seq[Vector], y: Seq[Vector]): Boolean = {
+        x.zip(y).exists(tuple => tuple._1 == tuple._2)
+      }
+
+      // In the origin dataset, find the hash value that hash the same bucket with the key
+      val sameBucketWithKeyUDF = udf((x: Seq[Vector]) =>
+        sameBucket(x, keyHash), DataTypes.BooleanType)
+
+      modelDataset.filter(sameBucketWithKeyUDF(col($(outputCol))))
     } else {
+      // In the origin dataset, find the hash value that is closest to the key
+      // Limit the use of hashDist since it's controversial
+      val hashDistUDF = udf((x: Seq[Vector]) => hashDistance(x, keyHash), DataTypes.DoubleType)
+      val hashDistCol = hashDistUDF(col($(outputCol)))
+
       // Compute threshold to get exact k elements.
       // TODO: SPARK-18409: Use approxQuantile to get the threshold
       val modelDatasetSortedByHash = modelDataset.sort(hashDistCol).limit(numNearestNeighbors)
       val thresholdDataset = modelDatasetSortedByHash.select(max(hashDistCol))
       val hashThreshold = thresholdDataset.take(1).head.getDouble(0)
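To unpack the multi-probe branch (a plain-Scala analogy with invented values, not Spark code): the threshold is the largest hash distance among the k hash-closest rows, and every row at or below it is retained, so ties can return more than k candidates.

```scala
// Hash distances from each row's signature to the key's (illustrative).
val hashDists = Seq(0.0, 2.0, 1.0, 2.0, 5.0)
val k = 3

// Threshold = max hash distance among the k closest rows,
// mirroring modelDatasetSortedByHash / hashThreshold above.
val hashThreshold = hashDists.sorted.take(k).max // 2.0

// Keep every row within the threshold; ties may yield more than k rows.
val candidates = hashDists.filter(_ <= hashThreshold) // Seq(0.0, 2.0, 1.0, 2.0)
```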
@@ -155,8 +148,30 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
   }
 
   /**
-   * Overloaded method for approxNearestNeighbors. Use Single Probing as default way to search
-   * nearest neighbors and "distCol" as default distCol.
+   * Given a large dataset and an item, approximately find at most k items which have the closest
+   * distance to the item. If the [[outputCol]] is missing, the method will transform the data; if
+   * the [[outputCol]] exists, it will use the [[outputCol]]. This allows caching of the
+   * transformed data when necessary.
+   *
+   * NOTE: This method is experimental and will likely change behavior in the next release.
+   *
+   * @param dataset the dataset to search for nearest neighbors of the key
+   * @param key Feature vector representing the item to search for
+   * @param numNearestNeighbors The maximum number of nearest neighbors
+   * @param distCol Output column for storing the distance between each result row and the key
+   * @return A dataset containing at most k items closest to the key. A distCol is added to show
+   *         the distance between each row and the key.
    */
   def approxNearestNeighbors(
       dataset: Dataset[_],
+      key: Vector,
+      numNearestNeighbors: Int,
+      distCol: String): Dataset[_] = {
+    approxNearestNeighbors(dataset, key, numNearestNeighbors, true, distCol)
+  }
+
+  /**
+   * Overloaded method for approxNearestNeighbors. Use "distCol" as default distCol.
+   */
+  def approxNearestNeighbors(
+      dataset: Dataset[_],

Review comment (on the new scaladoc fields): nit: Capitalize first words and add periods to all fields
Reply: Done.

Review comment: minor:
Reply: Done.
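A hedged usage sketch of the resulting public API (everything here is illustrative: the toy data, parameter values, and local SparkSession are assumptions; `BucketedRandomProjectionLSH` is the estimator exercised in the reviewer's test below):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.feature.BucketedRandomProjectionLSH
import org.apache.spark.ml.linalg.Vectors

val spark = SparkSession.builder.master("local").appName("ann-sketch").getOrCreate()

// Toy data: four 2-dimensional points in a "keys" column.
val df = spark.createDataFrame(Seq(
  Tuple1(Vectors.dense(1.0, 1.0)),
  Tuple1(Vectors.dense(1.0, -1.0)),
  Tuple1(Vectors.dense(-1.0, -1.0)),
  Tuple1(Vectors.dense(-1.0, 1.0)))).toDF("keys")

val brp = new BucketedRandomProjectionLSH()
  .setNumHashTables(3)
  .setInputCol("keys")
  .setOutputCol("values")
  .setBucketLength(2.0)
  .setSeed(12345)
val model = brp.fit(df)

// Public overload: single-probe search for at most 2 neighbors,
// with the caller-chosen distance column "distCol".
val neighbors = model.approxNearestNeighbors(df, Vectors.dense(1.0, 0.0), 2, "distCol")
neighbors.show()
```

Note that the `singleProbe` variant is now `private[feature]` pending SPARK-18454, so callers only choose the distance column.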
@@ -179,16 +194,13 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
       inputName: String,
       explodeCols: Seq[String]): Dataset[_] = {
     require(explodeCols.size == 2, "explodeCols must be two strings.")
-    val vectorToMap = udf((x: Vector) => x.asBreeze.iterator.toMap,
-      MapType(DataTypes.IntegerType, DataTypes.DoubleType))
     val modelDataset: DataFrame = if (!dataset.columns.contains($(outputCol))) {
       transform(dataset)
     } else {
       dataset.toDF()
     }
     modelDataset.select(
-      struct(col("*")).as(inputName),
-      explode(vectorToMap(col($(outputCol)))).as(explodeCols))
+      struct(col("*")).as(inputName), posexplode(col($(outputCol))).as(explodeCols))
   }
 
   /**

Review comment (on the posexplode line): Well here's a fun one. When I run this test:

    test("memory leak test") {
      val numDim = 50
      val data = {
        for (i <- 0 until numDim; j <- Seq(-2, -1, 1, 2))
          yield Vectors.sparse(numDim, Seq((i, j.toDouble)))
      }
      val df = spark.createDataFrame(data.map(Tuple1.apply)).toDF("keys")
      // Hash the 50-dimensional Euclidean points into 10 hash tables
      val brp = new BucketedRandomProjectionLSH()
        .setNumHashTables(10)
        .setInputCol("keys")
        .setOutputCol("values")
        .setBucketLength(2.5)
        .setSeed(12345)
      val model = brp.fit(df)
      val joined = model.approxSimilarityJoin(df, df, Double.MaxValue, "distCol")
      joined.show()
    }

I get the following error:
Could you run the same test and see if you get an error?

Reply: I did not get the same error, and the result shows successfully. Could you provide me with the full stack of the Exception?

Review comment: Yeah I still get it. Did you use the code above? It's not directly copy pasted from the existing tests.

Reply: Yes, I copied your code to
Let me see if the test can pass jenkins or not.

Review comment: If you look at line 292 of

Reply: See #15916
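For context on the `posexplode` change (an illustrative sketch, not the PR's code; the frame and column names are invented, and plain Double hashes stand in for hash Vectors): `posexplode` emits one row per array element together with its position, so each record expands into (hash table index, hash value) pairs that the similarity join can match on.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, posexplode, struct}

val spark = SparkSession.builder.master("local").appName("posexplode-sketch").getOrCreate()
import spark.implicits._

// Stand-in for a transformed LSH output: an id plus one hash value per hash table.
val hashed = Seq((0, Seq(1.0, 4.0, 2.0)), (1, Seq(1.0, 5.0, 3.0))).toDF("id", "values")

// Each input row becomes one row per hash table: the struct of all original
// columns plus the table's position and that table's hash value.
val exploded = hashed.select(
  struct(col("*")).as("entry"),
  posexplode(col("values")).as(Seq("hashTableIdx", "hashValue")))
exploded.show()
```

Joining two frames exploded this way on equal position and hash value yields the OR-amplified candidate pairs, without materializing the per-dimension map the old `explode(vectorToMap(...))` approach required.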
@@ -293,7 +305,7 @@ private[ml] abstract class LSH[T <: LSHModel[T]]
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
   /** @group setParam */
-  def setOutputDim(value: Int): this.type = set(outputDim, value)
+  def setNumHashTables(value: Int): this.type = set(numHashTables, value)
 
   /**
    * Validate and create a new instance of concrete LSHModel. Because different LSHModel may have
Review comment: minor: use @note
Reply: Done.