Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Spark-7422][MLLIB] Add argmax to Vector, SparseVector #6112

Closed
wants to merge 21 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
04677af
initial work on adding argmax to Vector and SparseVector
May 11, 2015
3cffed4
Adding unit tests for argmax functions for Dense and Sparse vectors
May 12, 2015
df9538a
Added argmax to sparse vector and added unit test
May 12, 2015
4526acc
Merge branch 'master' of github.com:apache/spark into SPARK-7422
May 13, 2015
eeda560
Fixing SparseVector argmax function to ignore zero values while doing…
May 15, 2015
af17981
Initial work fixing bug that was made clear in pr
dittmarg May 22, 2015
f21dcce
commit
GeorgeDittmar May 25, 2015
b1f059f
Added comment before we start arg max calculation. Updated unit tests…
GeorgeDittmar May 29, 2015
3ee8711
Fixing corner case issue with zeros in the active values of the spars…
GeorgeDittmar Jun 1, 2015
ee1a85a
Cleaning up unit tests a bit and modifying a few cases
GeorgeDittmar Jun 1, 2015
d5b5423
Fixing code style and updating if logic on when to check for zero values
GeorgeDittmar Jun 9, 2015
ac53c55
changing dense vector argmax unit test to be one line call vs 2
GeorgeDittmar Jun 9, 2015
aa330e3
Fixing some last if else spacing issues
GeorgeDittmar Jun 9, 2015
f2eba2f
Cleaning up unit tests to be fewer lines
GeorgeDittmar Jun 9, 2015
b22af46
Fixing spaces between commas in unit test
GeorgeDittmar Jun 10, 2015
42341fb
refactoring arg max check to better handle zero values
GeorgeDittmar Jul 9, 2015
5fd9380
fixing style check error
GeorgeDittmar Jul 9, 2015
98058f4
Merge branch 'master' of github.com:apache/spark into SPARK-7422
GeorgeDittmar Jul 15, 2015
2ea6a55
Added MimaExcludes for Vectors.argmax
GeorgeDittmar Jul 15, 2015
127dec5
update argmax impl
mengxr Jul 17, 2015
3e0a939
Merge pull request #1 from mengxr/SPARK-7422
GeorgeDittmar Jul 18, 2015
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 52 additions & 5 deletions mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,12 @@ sealed trait Vector extends Serializable {
toDense
}
}

/**
* Find the index of a maximal element. Returns the first maximal element in case of a tie.
* Returns -1 if vector has length 0.
*/
def argmax: Int
}

/**
Expand Down Expand Up @@ -588,11 +594,7 @@ class DenseVector(val values: Array[Double]) extends Vector {
new SparseVector(size, ii, vv)
}

/**
* Find the index of a maximal element. Returns the first maximal element in case of a tie.
* Returns -1 if vector has length 0.
*/
private[spark] def argmax: Int = {
override def argmax: Int = {
if (size == 0) {
-1
} else {
Expand Down Expand Up @@ -717,6 +719,51 @@ class SparseVector(
new SparseVector(size, ii, vv)
}
}

override def argmax: Int = {
if (size == 0) {
-1
} else {
// Find the max active entry.
var maxIdx = indices(0)
var maxValue = values(0)
var maxJ = 0
var j = 1
val na = numActives
while (j < na) {
val v = values(j)
if (v > maxValue) {
maxValue = v
maxIdx = indices(j)
maxJ = j
}
j += 1
}

// If the max active entry is nonpositive and there exists inactive ones, find the first zero.
if (maxValue <= 0.0 && na < size) {
if (maxValue == 0.0) {
// If there exists an inactive entry before maxIdx, find it and return its index.
if (maxJ < maxIdx) {
var k = 0
while (k < maxJ && indices(k) == k) {
k += 1
}
maxIdx = k
}
} else {
// If the max active value is negative, find and return the first inactive index.
var k = 0
while (k < na && indices(k) == k) {
k += 1
}
maxIdx = k
}
}

maxIdx
}
}
}

object SparseVector {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,50 @@ class VectorsSuite extends SparkFunSuite with Logging {
assert(vec.toArray.eq(arr))
}

test("dense argmax") {
val vec = Vectors.dense(Array.empty[Double]).asInstanceOf[DenseVector]
assert(vec.argmax === -1)

val vec2 = Vectors.dense(arr).asInstanceOf[DenseVector]
assert(vec2.argmax === 3)

val vec3 = Vectors.dense(Array(-1.0, 0.0, -2.0, 1.0)).asInstanceOf[DenseVector]
assert(vec3.argmax === 3)
}

test("sparse to array") {
val vec = Vectors.sparse(n, indices, values).asInstanceOf[SparseVector]
assert(vec.toArray === arr)
}

test("sparse argmax") {
val vec = Vectors.sparse(0, Array.empty[Int], Array.empty[Double]).asInstanceOf[SparseVector]
assert(vec.argmax === -1)

val vec2 = Vectors.sparse(n, indices, values).asInstanceOf[SparseVector]
assert(vec2.argmax === 3)

val vec3 = Vectors.sparse(5, Array(2, 3, 4), Array(1.0, 0.0, -.7))
assert(vec3.argmax === 2)

// check for case that sparse vector is created with
// only negative values {0.0, 0.0,-1.0, -0.7, 0.0}
val vec4 = Vectors.sparse(5, Array(2, 3), Array(-1.0, -.7))
assert(vec4.argmax === 0)

val vec5 = Vectors.sparse(11, Array(0, 3, 10), Array(-1.0, -.7, 0.0))
assert(vec5.argmax === 1)

val vec6 = Vectors.sparse(11, Array(0, 1, 2), Array(-1.0, -.7, 0.0))
assert(vec6.argmax === 2)

val vec7 = Vectors.sparse(5, Array(0, 1, 3), Array(-1.0, 0.0, -.7))
assert(vec7.argmax === 1)

val vec8 = Vectors.sparse(5, Array(1, 2), Array(0.0, -1.0))
assert(vec8.argmax === 0)
}

test("vector equals") {
val dv1 = Vectors.dense(arr.clone())
val dv2 = Vectors.dense(arr.clone())
Expand Down
4 changes: 4 additions & 0 deletions project/MimaExcludes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ object MimaExcludes {
"org.apache.spark.api.r.StringRRDD.this"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.api.r.BaseRRDD.this")
) ++ Seq(
// SPARK-7422 add argmax for sparse vectors
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.mllib.linalg.Vector.argmax")
)

case v if v.startsWith("1.4") =>
Expand Down