diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
index 76ecdf92f26ed..a3681e34a147d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
@@ -247,9 +247,34 @@ class LDA private (
     new DistributedLDAModel(state, iterationTimes)
   }
 
-  def runOnlineLDA(documents: RDD[(Long, Vector)]): LDAModel = {
-    val onlineLDA = new LDA.OnlineLDAOptimizer(documents, k)
-    (0 until onlineLDA.batchNumber).map(_ => onlineLDA.next())
+  /**
+   * Learn an LDA model from the given dataset using the online variational Bayes (VB)
+   * algorithm of Hoffman, Blei and Bach, "Online Learning for Latent Dirichlet Allocation."
+   * NIPS, 2010.
+   *
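+   * A minimal usage sketch (assuming `corpus` is an RDD[(Long, Vector)] of term-count vectors
+   * with unique non-negative document IDs):
+   * {{{
+   *   val model = new LDA().setK(10).runOnlineLDA(corpus, batchNumber = 100)
+   * }}}
+   *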
+   * @param documents  RDD of documents, which are term (word) count vectors paired with IDs.
+   *                   The term count vectors are "bags of words" with a fixed-size vocabulary
+   *                   (where the vocabulary size is the length of the vector).
+   *                   Document IDs must be unique and >= 0.
+   * @param batchNumber Number of mini-batches to split the corpus into. The recommended
+   *                    batch size (documents per batch) is in [4, 16384]. Use -1 to choose
+   *                    the batch size automatically.
+   * @return  Inferred LDA model
+   */
+  def runOnlineLDA(documents: RDD[(Long, Vector)], batchNumber: Int = -1): LDAModel = {
+    val D = documents.count().toInt
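+    // Mini-batch size: in auto mode, target roughly D / 100 documents per batch, clamped to
+    // the recommended range [4, 16384]; otherwise split the corpus evenly across batchNumber.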
+    val batchSize =
+      if (batchNumber == -1) { // auto mode
+        math.min(16384, math.max(4, D / 100))
+      } else {
+        require(batchNumber > 0, "batchNumber should be positive or -1")
+        D / batchNumber
+      }
+
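+    // Run one online VB update per sampled mini-batch.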
+    val onlineLDA = new LDA.OnlineLDAOptimizer(documents, k, batchSize)
+    (0 until onlineLDA.actualBatchNumber).foreach(_ => onlineLDA.next())
     new LocalLDAModel(Matrices.fromBreeze(onlineLDA.lambda).transpose)
   }
 
@@ -411,28 +436,26 @@ private[clustering] object LDA {
    * Hoffman, Blei and Bach, “Online Learning for Latent Dirichlet Allocation.” NIPS, 2010.
    */
   private[clustering] class OnlineLDAOptimizer(
-    private val documents: RDD[(Long, Vector)],
-    private val k: Int) extends Serializable{
+      private val documents: RDD[(Long, Vector)],
+      private val k: Int,
+      private val batchSize: Int) extends Serializable {
 
     private val vocabSize = documents.first._2.size
     private val D = documents.count().toInt
-    private val batchSize = if (D / 1000 > 4096) 4096
-                            else if (D / 1000 < 4) 4
-                            else D / 1000
-    val batchNumber = D/batchSize
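+    // Number of mini-batch iterations needed to cover the corpus once (rounded up).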
+    val actualBatchNumber = math.ceil(D.toDouble / batchSize).toInt
 
     // Initialize the variational distribution q(beta|lambda)
     var lambda = getGammaMatrix(k, vocabSize)               // K * V
     private var Elogbeta = dirichlet_expectation(lambda)    // K * V
     private var expElogbeta = exp(Elogbeta)                 // K * V
 
     private var batchId = 0
     def next(): Unit = {
-      require(batchId < batchNumber)
+      require(batchId < actualBatchNumber)
       // weight of the mini-batch. 1024 down weights early iterations
       val weight = math.pow(1024 + batchId, -0.5)
-      val batch = documents.filter(doc => doc._1 % batchNumber == batchId)
-
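+      // Draw a random mini-batch of roughly batchSize documents (sampled with replacement)
+      // and cache it for the update below.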
+      val batch = documents.sample(withReplacement = true, batchSize.toDouble / D)
+      batch.cache()
+
       // Given a mini-batch of documents, estimates the parameters gamma controlling the
       // variational distribution over the topic weights for each document in the mini-batch.
       var stat = BDM.zeros[Double](k, vocabSize)