Skip to content

Commit

Permalink
feat: save changes
Browse files Browse the repository at this point in the history
  • Loading branch information
viacheslav-dobrynin committed Feb 11, 2024
1 parent 843ec5c commit eb63da8
Show file tree
Hide file tree
Showing 15 changed files with 111 additions and 33 deletions.
2 changes: 2 additions & 0 deletions src/main/kotlin/ru/itmo/stand/StandApplication.kt
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,6 @@ TODO:
2. бэнчмарки - https://arxiv.org/pdf/2105.04021.pdf, https://github.com/castorini/anserini
по
search -m NEIGHBOURS -f MS_MARCO /home/user/Desktop/queries.out.tsv
save-in-batch -m NEIGHBOURS --with-id /home/user/Downloads/msmarco/collection.tsv
*/
12 changes: 11 additions & 1 deletion src/main/kotlin/ru/itmo/stand/command/SearchCommand.kt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import picocli.CommandLine.Command
import picocli.CommandLine.Option
import picocli.CommandLine.Parameters
import ru.itmo.stand.service.DocumentService
import ru.itmo.stand.service.bert.BertEmbeddingCalculator
import ru.itmo.stand.service.bert.TranslatorInput
import ru.itmo.stand.service.model.Format
import ru.itmo.stand.service.model.Format.JUST_QUERY
import ru.itmo.stand.util.measureTimeSeconds
Expand All @@ -16,7 +18,10 @@ import java.io.File
mixinStandardHelpOptions = true,
description = ["Return IDs of documents relevant to the query."],
)
class SearchCommand(private val documentService: DocumentService) : Runnable {
class SearchCommand(
private val documentService: DocumentService,
private val bertEmbeddingCalculator: BertEmbeddingCalculator,
) : Runnable {

@Parameters(
paramLabel = "queries file",
Expand All @@ -29,6 +34,11 @@ class SearchCommand(private val documentService: DocumentService) : Runnable {
private var format: Format = JUST_QUERY

override fun run() {
val inputs = queries.bufferedReader().readLines().map { TranslatorInput.withClsWordIndex(it) }
val aLotOfInputs = (1..10).flatMap { inputs }
repeat(10) {
bertEmbeddingCalculator.calculate(aLotOfInputs, 1000)
}
val latencyInSeconds = measureTimeSeconds {
println(documentService.search(queries, format).ifEmpty { "Documents not found." })
}
Expand Down
2 changes: 1 addition & 1 deletion src/main/kotlin/ru/itmo/stand/config/NlpConfig.kt
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,6 @@ class NlpConfig {
}

companion object {
const val ANNOTATORS = "tokenize,pos,lemma"
const val ANNOTATORS = "tokenize"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class BertEmbeddingCalculator(
) {

private val predictor by lazy {
bertModelLoader.loadModel(standProperties.app.neighboursAlgorithm.bertModelType).newPredictor()
bertModelLoader.loadModel(standProperties.app.neighboursAlgorithm.bertModelType).newPredictor() // TODO: fix me
}

/**
 * Computes the embedding for a single [input].
 *
 * The input is wrapped into a one-element array, passed to `predictor.predict`,
 * and the first element of the prediction result is returned.
 */
fun calculate(input: TranslatorInput): FloatArray = predictor.predict(arrayOf(input)).first()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class DocumentAnnService(

private fun search(query: String): List<String> {
val queryVector = bertEmbeddingCalculator.calculate(TranslatorInput.withClsWordIndex(query)).toTypedArray()
val results = documentEmbeddingInMemoryRepository.findByVector(queryVector, 10)
val results = documentEmbeddingInMemoryRepository.findExactByVector(queryVector, 10)
return results.map { it.id }
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,15 @@ class DocumentEmbeddingCreator(
private val log = KotlinLogging.logger { }

fun create(documents: Sequence<Document>) {
if (documentEmbeddingRepository.countAll() > 0) {
log.info {
"The embeddings already exist. " +
"The addition of new embeddings is not being executed. " +
"The creation process has been omitted. " +
"If you wish to add new ones, please remove the previous ones."
}
return
}
documents.onEachIndexed { index, _ -> if (index % 10000 == 0) log.info { "Document embeddings created: $index" } }
.chunked(BERT_BATCH_SIZE)
.forEach { chunk ->
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package ru.itmo.stand.service.impl.neighbours.indexing

import io.github.oshai.KotlinLogging
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.asFlow
import kotlinx.coroutines.runBlocking
import org.springframework.stereotype.Service
import ru.itmo.stand.config.StandProperties
import ru.itmo.stand.service.bert.BertEmbeddingCalculator
Expand All @@ -10,8 +13,10 @@ import ru.itmo.stand.storage.embedding.neighbours.model.ContextualizedEmbedding
import ru.itmo.stand.storage.lucene.model.neighbours.NeighboursDocument
import ru.itmo.stand.storage.lucene.repository.neighbours.DocumentEmbeddingRepository
import ru.itmo.stand.storage.lucene.repository.neighbours.InvertedIndex
import ru.itmo.stand.util.dot
import ru.itmo.stand.util.cos
import ru.itmo.stand.util.processConcurrently
import java.io.File
import java.util.concurrent.ConcurrentHashMap

@Service
class InvertedIndexBuilder(
Expand All @@ -23,14 +28,16 @@ class InvertedIndexBuilder(
) {

private val log = KotlinLogging.logger { }
private val documentEmbeddingCache = HashMap<String, FloatArray>()
private val documentEmbeddingCache = ConcurrentHashMap<String, FloatArray>()

fun index(windowedTokensFile: File) {
fun index(windowedTokensFile: File) = runBlocking(Dispatchers.Default) {
val tokensWithWindows = readTokensWindowsAndDocIds(windowedTokensFile)

tokensWithWindows.onEachIndexed { index, token ->
log.info { "Tokens processed: $index. Current token: ${token.token}. Windows size: ${token.docIdsByWindowPairs.size}" }
}.forEach { tokenWithWindows ->
processConcurrently(
tokensWithWindows.asFlow(),
10,
{ log.info { "Tokens processed: $it" } },
) { tokenWithWindows ->
val (_, docIdsByWindowPairs) = tokenWithWindows
val (windows, docIdsList) = docIdsByWindowPairs.unzip()
embeddingCalculator.calculate(windows, standProperties.app.neighboursAlgorithm.bertWindowBatchSize)
Expand Down Expand Up @@ -73,7 +80,7 @@ class InvertedIndexBuilder(
NeighboursDocument(
tokenWithEmbeddingId = contextualizedEmbedding.tokenWithEmbeddingId,
docId = docId,
score = documentEmbedding.dot(contextualizedEmbedding.embedding),
score = documentEmbedding.cos(contextualizedEmbedding.embedding),
)
}
invertedIndex.saveAll(neighboursDocuments)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ class VectorIndexBuilder(

suspend fun index(windowedTokensFile: File) {
log.info { "Starting vector indexing" }
log.info { "SKIP" }
return
val tokensWindows = readTokensWindows(windowedTokensFile)

val counter = AtomicInteger(0)
Expand Down Expand Up @@ -81,7 +83,7 @@ class VectorIndexBuilder(

log.info { "${tokenInputs.token} has ${embeddings.size} embeddings" }

val clusterModel = XMeans.fit(embeddings.toDoubleArray(), 8) // TODO: configure this value
val clusterModel = XMeans.fit(embeddings.toDoubleArray(), 4) // TODO: configure this value

log.info { "${tokenInputs.token} got ${clusterModel.k} centroids" }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,24 @@ class WindowedTokenCreator(
private val log = KotlinLogging.logger { }

fun create(documents: Sequence<Document>): File {
val windowedTokensFile = File("${standProperties.app.basePath}/indexes/neighbours/windowed-tokens.txt")
if (windowedTokensFile.exists()) {
log.info {
"The windowedTokensFile already exist. " +
"The addition of new tokens is not being executed. " +
"The creation process has been omitted. " +
"If you wish to add new ones, please remove the previous ones."
}
return windowedTokensFile
}
val memoryIndex = constructMemoryIndex(documents)

log.info { "MemoryIndex is constructed. Token number: ${memoryIndex.size}" }
log.info { "Min windows: ${memoryIndex.values.minBy { it.keys.size }.keys.size}" }
log.info { "Max windows: ${memoryIndex.values.maxBy { it.keys.size }.keys.size}" }
log.info { "Mean windows: ${memoryIndex.values.map { it.keys.size }.average()}" }

val windowedTokensFile = File("${standProperties.app.basePath}/indexes/neighbours/windowed-tokens.txt")
.createPath()

writeMemoryIndexToFile(memoryIndex, windowedTokensFile)
writeMemoryIndexToFile(memoryIndex, windowedTokensFile.createPath())

return windowedTokensFile
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
package ru.itmo.stand.service.impl.neighbours.search

import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.asFlow
import kotlinx.coroutines.flow.buffer
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.runBlocking
import org.springframework.stereotype.Service
import ru.itmo.stand.service.bert.BertEmbeddingCalculator
import ru.itmo.stand.service.impl.neighbours.PreprocessingPipelineExecutor
import ru.itmo.stand.storage.embedding.neighbours.ContextualizedEmbeddingRepository
import ru.itmo.stand.storage.lucene.model.neighbours.NeighboursDocument
import ru.itmo.stand.storage.lucene.repository.neighbours.InvertedIndex
import java.util.stream.Collectors.groupingBy

@Service
class NeighboursSearcher(
Expand All @@ -14,17 +21,44 @@ class NeighboursSearcher(
private val invertedIndex: InvertedIndex,
) {

fun search(query: String): List<String> {
/**
* TODO:
* 1. sort by doc id
* 2. use field selector to load only doc id and score
* 3. try hit collector
* 4. try cache field
* 5. use filter instead of query
* see https://cwiki.apache.org/confluence/display/lucene/ImproveSearchingSpeed#
*/
fun search(query: String): List<String> = runBlocking(Dispatchers.Default) {
val windows = preprocessingPipelineExecutor.execute(query)
val embeddings = bertEmbeddingCalculator.calculate(windows.map { it.toTranslatorInput() }.toTypedArray())
return embeddings.flatMap { embedding -> contextualizedEmbeddingRepository.findByVector(embedding.toTypedArray()) }
.let { contextualizedEmbeddings ->
val tokenWithEmbeddingIds = contextualizedEmbeddings.map { it.tokenWithEmbeddingId }
invertedIndex.findByTokenWithEmbeddingIds(tokenWithEmbeddingIds).groupingBy { it.docId }
.foldTo(HashMap(), 0f) { acc, doc -> acc + doc.score }
}.entries
val documents = mutableSetOf<NeighboursDocument>()
windows.asFlow()
.map { bertEmbeddingCalculator.calculate(it.toTranslatorInput()) }
.buffer()
.map { contextualizedEmbeddingRepository.findByVector(it.toTypedArray()) }
.buffer()
.collect { embs ->
documents += invertedIndex.findByTokenWithEmbeddingIds(embs.map { it.tokenWithEmbeddingId })
}
documents
.groupingBy { it.docId }
.foldTo(HashMap(), 0f) { acc, doc -> acc + doc.score }
.entries
.sortedByDescending { (_, score) -> score }
.take(10) // TODO: configure this value
.map { (docId, _) -> docId }
// val embeddings = bertEmbeddingCalculator.calculate(windows.map { it.toTranslatorInput() }.toTypedArray())
// embeddings.flatMap { embedding -> contextualizedEmbeddingRepository.findByVector(embedding.toTypedArray()) }
// .let { contextualizedEmbeddings ->
// val tokenWithEmbeddingIds = contextualizedEmbeddings.map { it.tokenWithEmbeddingId }
// invertedIndex.findByTokenWithEmbeddingIds(tokenWithEmbeddingIds)
// .toSet()
// .groupingBy { it.docId }
// .foldTo(HashMap(), 0f) { acc, doc -> acc + doc.score }
// }.entries
// .sortedByDescending { (_, score) -> score }
// .take(10) // TODO: configure this value
// .map { (docId, _) -> docId }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class DocumentEmbeddingInMemoryRepository(
private val itemSerializer = JavaObjectSerializer<DocumentEmbedding>()
private val indexFile = File("${standProperties.app.basePath}/indexes/ann/hnsw").createPath()

private var index = runCatching {
private val index = runCatching {
HnswIndex.load<String, FloatArray, DocumentEmbedding, Float>(indexFile)
}.getOrElse {
log.info { "Got exception [${it.javaClass.simpleName}] during index loading with message: ${it.message}" }
Expand All @@ -36,6 +36,7 @@ class DocumentEmbeddingInMemoryRepository(
.withEfConstruction(128)
.build()
}
private val exactIndex = index.asExactIndex()

@PreDestroy
fun saveIndex() {
Expand All @@ -45,6 +46,9 @@ class DocumentEmbeddingInMemoryRepository(
/**
 * Looks up neighbours of [vector] in `index`.
 *
 * The boxed [vector] is converted to a primitive `FloatArray`, then
 * `index.findNearest` is called with it and [topN]; the item of each
 * returned search result is collected into the resulting list.
 */
fun findByVector(vector: Array<Float>, topN: Int): List<DocumentEmbedding> =
index.findNearest(vector.toFloatArray(), topN).map { it.item() }

/**
 * Looks up neighbours of [vector] in `exactIndex` (the view obtained from
 * `index.asExactIndex()`), instead of querying `index` directly.
 *
 * The boxed [vector] is converted to a primitive `FloatArray`, then
 * `exactIndex.findNearest` is called with it and [topN]; the item of each
 * returned search result is collected into the resulting list.
 *
 * NOTE(review): presumably this performs an exhaustive (non-approximate) search
 * and is therefore slower than [findByVector] — confirm against the HNSW library docs.
 */
fun findExactByVector(vector: Array<Float>, topN: Int): List<DocumentEmbedding> =
exactIndex.findNearest(vector.toFloatArray(), topN).map { it.item() }

/**
 * Adds [embedding] to the in-memory `index`.
 *
 * NOTE(review): nothing is written to disk here; presumably persistence is
 * handled by the `@PreDestroy` `saveIndex` hook — confirm.
 */
fun index(embedding: DocumentEmbedding) {
index.add(embedding)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class ContextualizedEmbeddingInMemoryRepository(

override fun findByVector(vector: Array<Float>): List<ContextualizedEmbedding> =
index.findNearest(vector.toFloatArray(), 10) // TODO: configure this value
.filter { it.distance() <= 5 } // TODO: configure this value
.filter { it.distance() <= 4 } // TODO: configure this value
.map { it.item() }

override fun index(embedding: ContextualizedEmbedding) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import org.apache.lucene.document.StringField
import org.apache.lucene.index.ConcurrentMergeScheduler
import org.apache.lucene.index.IndexWriterConfig
import org.apache.lucene.index.Term
import org.apache.lucene.search.MatchAllDocsQuery
import org.apache.lucene.search.TermQuery
import org.springframework.stereotype.Repository
import ru.itmo.stand.config.StandProperties
Expand Down Expand Up @@ -67,4 +68,6 @@ class DocumentEmbeddingRepository(private val standProperties: StandProperties)
writer.forceMerge(1, true)
writer.commit()
}

/**
 * Returns the number of documents that `searcher` counts for a [MatchAllDocsQuery].
 *
 * NOTE(review): the result reflects whatever view of the index `searcher` holds;
 * whether recently written but uncommitted documents are included depends on how
 * `searcher` is obtained — confirm.
 */
fun countAll(): Int = searcher.count(MatchAllDocsQuery())
}
2 changes: 1 addition & 1 deletion src/main/kotlin/ru/itmo/stand/util/Preprocessing.kt
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ fun String.toNgrams(minGram: Int = 2, maxGram: Int = 2): List<String> {

fun String.toTokens(stanfordCoreNLP: StanfordCoreNLP): List<String> = stanfordCoreNLP.processToCoreDocument(this)
.tokens()
.map { it.lemma().lowercase() }
.map { it.value().lowercase() }

/**
* For n tokens and size = m,
Expand Down
12 changes: 6 additions & 6 deletions tools/run_stand.sh
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# Initialize
method=bm25
repo_path=/your/repo/path/
collection_path=collections/collection.500k.tsv
queries_path=collections/queries.500k.tsv
qrels_path=collections/qrels.500k.tsv
is_need_to_index=true
method=ann
repo_path=/home/user/IdeaProjects/IR-stand/
collection_path=collections/collection.with-preprocessed-ids.500k.tsv
queries_path=collections/queries.preprocessed.500k.tsv
qrels_path=collections/qrels.preprocessed.500k.tsv
is_need_to_index=false

# Rebuild
cd $repo_path || { echo "Failure"; exit 1; }
Expand Down

0 comments on commit eb63da8

Please sign in to comment.