From 861ec48bc74616b47d45ad3b828097a35045050f Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 8 Jul 2014 18:09:23 -0700 Subject: [PATCH] simplify axpy --- .../mllib/linalg/distributed/RowMatrix.scala | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 29c0adc51fe2b..67d79109cda00 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -202,32 +202,23 @@ class RowMatrix( } /** - * Multiply the Gramian matrix `A^T A` by a DenseVector on the right. + * Multiplies the Gramian matrix `A^T A` by a dense vector on the right without computing `A^T A`. * - * @param v a local DenseVector whose length must match the number of columns of this matrix. - * @return a local DenseVector representing the product. + @param v a dense vector whose length must match the number of columns of this matrix + * @return a dense vector representing the product */ private[mllib] def multiplyGramianMatrixBy(v: BDV[Double]): BDV[Double] = { val n = numCols().toInt val vbr = rows.context.broadcast(v) - - val bv = rows.aggregate(BDV.zeros[Double](n))( + rows.aggregate(BDV.zeros[Double](n))( seqOp = (U, r) => { val rBrz = r.toBreeze val a = rBrz.dot(vbr.value) - rBrz match { - case _: BDV[_] => brzAxpy(a, rBrz.asInstanceOf[BDV[Double]], U) - case _: BSV[_] => brzAxpy(a, rBrz.asInstanceOf[BSV[Double]], U) - case _ => - throw new UnsupportedOperationException( - s"Do not support vector operation from type ${rBrz.getClass.getName}.") - } + brzAxpy(a, rBrz, U.asInstanceOf[BV[Double]]) U }, combOp = (U1, U2) => U1 += U2 ) - - bv } /**