Allow to train pipeline several times
aleksei-ai committed Dec 22, 2017
1 parent 0140c30 commit 6f30421
Showing 1 changed file with 12 additions and 11 deletions.
@@ -36,27 +36,28 @@ abstract class AnnotatorWithWordEmbeddings[A <: AnnotatorWithWordEmbeddings[A, M

   override def beforeTraining(spark: SparkSession): Unit = {
     if (isDefined(sourceEmbeddingsPath)) {
-      indexEmbeddings(localPath, spark.sparkContext)
-      WordEmbeddingsClusterHelper.copyIndexToCluster(localPath, spark.sparkContext)
+      // 1. Create tmp file for index
+      localPath = Some(WordEmbeddingsClusterHelper.createLocalPath())
+      // 2. Index Word Embeddings
+      indexEmbeddings(localPath.get, spark.sparkContext)
+      // 3. Copy WordEmbeddings to cluster
+      WordEmbeddingsClusterHelper.copyIndexToCluster(localPath.get, spark.sparkContext)
+      // 4. Create Embeddings for usage during train
+      embeddings = Some(WordEmbeddings(localPath.get, $(embeddingsNDims)))
     }
   }

   override def onTrained(model: M, spark: SparkSession): Unit = {
     if (isDefined(sourceEmbeddingsPath)) {
-      model.setDims($(embeddingsNDims))
+      val fileName = WordEmbeddingsClusterHelper.getClusterFileName(localPath.get).toString

-      val fileName = WordEmbeddingsClusterHelper.getClusterFileName(localPath).toString
+      model.setDims($(embeddingsNDims))
       model.setIndexPath(fileName)
     }
   }

-  lazy val embeddings: Option[WordEmbeddings] = {
-    get(sourceEmbeddingsPath).map(_ => WordEmbeddings(localPath, $(embeddingsNDims)))
-  }
-
-  private lazy val localPath: String = {
-    WordEmbeddingsClusterHelper.createLocalPath
-  }
+  var embeddings: Option[WordEmbeddings] = None
+  private var localPath: Option[String] = None

   private def indexEmbeddings(localFile: String, spark: SparkContext): Unit = {
     val formatId = $(embeddingsFormat)
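Why the commit drops the lazy vals: a Scala lazy val is evaluated once per instance and then cached, so a second fit of the same annotator would keep reusing the first temporary index path and the WordEmbeddings built from it, which is presumably what blocked training the pipeline more than once. The diff replaces them with plain mutable Options that beforeTraining re-initializes on every run. Below is a minimal, self-contained sketch of that pattern, not Spark NLP's actual class; createLocalPath() here is a hypothetical stand-in for WordEmbeddingsClusterHelper.createLocalPath().

import java.nio.file.Files

// Minimal sketch of the pattern in this commit: per-run mutable state
// instead of a once-cached lazy val.
object RetrainSketch {

  // Hypothetical stand-in for WordEmbeddingsClusterHelper.createLocalPath():
  // every call yields a fresh temporary location.
  private def createLocalPath(): String =
    Files.createTempDirectory("embeddings_idx").toString

  class Trainer {
    // Before the change (would cache the first path forever):
    //   private lazy val localPath: String = createLocalPath()
    // After the change: plain mutable state, re-initialized on every training run.
    private var localPath: Option[String] = None

    def beforeTraining(): Unit = {
      localPath = Some(createLocalPath())                    // 1. fresh tmp index location
      println(s"indexing embeddings into ${localPath.get}")  // 2.-4. index, copy, load
    }
  }

  def main(args: Array[String]): Unit = {
    val trainer = new Trainer
    trainer.beforeTraining() // first fit
    trainer.beforeTraining() // second fit gets its own path, not the stale cached one
  }
}

With the vars, each call to beforeTraining re-runs steps 1-4 against a fresh path, so a pipeline that embeds this annotator can be fit repeatedly.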
