diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/AnnotatorWithWordEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/AnnotatorWithWordEmbeddings.scala
index 25588fc59911b8..6b69f24aa62ea3 100644
--- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/AnnotatorWithWordEmbeddings.scala
+++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/AnnotatorWithWordEmbeddings.scala
@@ -36,27 +36,42 @@ abstract class AnnotatorWithWordEmbeddings[A <: AnnotatorWithWordEmbeddings[A, M
 
+  /** When a source embeddings path is defined: creates a local index path, indexes the
+    * embeddings into it, copies the index to the cluster and builds `embeddings` from it.
+    */
   override def beforeTraining(spark: SparkSession): Unit = {
     if (isDefined(sourceEmbeddingsPath)) {
-      indexEmbeddings(localPath, spark.sparkContext)
-      WordEmbeddingsClusterHelper.copyIndexToCluster(localPath, spark.sparkContext)
+      // 1. Create tmp file for index
+      val localIndexPath = WordEmbeddingsClusterHelper.createLocalPath()
+      localPath = Some(localIndexPath)
+      // 2. Index Word Embeddings
+      indexEmbeddings(localIndexPath, spark.sparkContext)
+      // 3. Copy WordEmbeddings to cluster
+      WordEmbeddingsClusterHelper.copyIndexToCluster(localIndexPath, spark.sparkContext)
+      // 4. Create Embeddings for usage during train
+      embeddings = Some(WordEmbeddings(localIndexPath, $(embeddingsNDims)))
     }
   }
 
+  /** When a source embeddings path is defined: sets the embeddings dimension and the cluster
+    * file name of the index recorded in `localPath` on the trained model.
+    */
   override def onTrained(model: M, spark: SparkSession): Unit = {
     if (isDefined(sourceEmbeddingsPath)) {
-      model.setDims($(embeddingsNDims))
+      // Fail with an explicit message rather than a bare `None.get` when no index was created
+      val localIndexPath = localPath.getOrElse(
+        throw new NoSuchElementException(
+          "Embeddings index path is not set; beforeTraining must run before onTrained"))
+      val fileName = WordEmbeddingsClusterHelper.getClusterFileName(localIndexPath).toString
 
-      val fileName = WordEmbeddingsClusterHelper.getClusterFileName(localPath).toString
+      model.setDims($(embeddingsNDims))
       model.setIndexPath(fileName)
     }
   }
 
-  lazy val embeddings: Option[WordEmbeddings] = {
-    get(sourceEmbeddingsPath).map(_ => WordEmbeddings(localPath, $(embeddingsNDims)))
-  }
-
-  private lazy val localPath: String = {
-    WordEmbeddingsClusterHelper.createLocalPath
-  }
+  /** Embeddings over the local index; assigned by `beforeTraining`, `None` until then. */
+  var embeddings: Option[WordEmbeddings] = None
+
+  /** Path of the local embeddings index; assigned by `beforeTraining`, `None` until then. */
+  private var localPath: Option[String] = None
 
   private def indexEmbeddings(localFile: String, spark: SparkContext): Unit = {
     val formatId = $(embeddingsFormat)