Skip to content

Commit

Permalink
Fix mode/comp params so parameter overrides work (opensearch-project#…
Browse files Browse the repository at this point in the history
…2083)

PR adds capability to override parameters when specifying mode and
compression. In order to do this, I add functionality for creating a
deep copy of KNNMethodContext and MethodComponentContext so that we
wouldnt overwrite user provided config. Then, re-arranged some of the
parameter resolution logic.

Signed-off-by: John Mazanec <jmazane@amazon.com>
  • Loading branch information
jmazanec15 authored Sep 10, 2024
1 parent 524dbd0 commit 270ac6a
Show file tree
Hide file tree
Showing 56 changed files with 2,715 additions and 479 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.engine;

import org.opensearch.common.ValidationException;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.mapper.CompressionLevel;
import org.opensearch.knn.index.mapper.Mode;

import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.Set;

import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;

/**
* Abstract {@link MethodResolver} with helpful utilitiy functions that can be shared across different
* implementations
*/
public abstract class AbstractMethodResolver implements MethodResolver {

/**
* Utility method to get the compression level from the context
*
* @param resolvedKnnMethodContext Resolved method context. Should have an encoder set in the params if available
* @return {@link CompressionLevel} Compression level that is configured with the {@link KNNMethodContext}
*/
protected CompressionLevel resolveCompressionLevelFromMethodContext(
KNNMethodContext resolvedKnnMethodContext,
KNNMethodConfigContext knnMethodConfigContext,
Map<String, Encoder> encoderMap
) {
// If the context is null, the compression is not configured or the encoder is not defined, return not configured
// because the method context does not contain this info
if (isEncoderSpecified(resolvedKnnMethodContext) == false) {
return CompressionLevel.x1;
}
Encoder encoder = encoderMap.get(getEncoderName(resolvedKnnMethodContext));
if (encoder == null) {
return CompressionLevel.NOT_CONFIGURED;
}
return encoder.calculateCompressionLevel(getEncoderComponentContext(resolvedKnnMethodContext), knnMethodConfigContext);
}

protected void resolveMethodParams(
MethodComponentContext methodComponentContext,
KNNMethodConfigContext knnMethodConfigContext,
MethodComponent methodComponent
) {
Map<String, Object> resolvedParams = MethodComponent.getParameterMapWithDefaultsAdded(
methodComponentContext,
methodComponent,
knnMethodConfigContext
);
methodComponentContext.getParameters().putAll(resolvedParams);
}

protected KNNMethodContext initResolvedKNNMethodContext(
KNNMethodContext originalMethodContext,
KNNEngine knnEngine,
SpaceType spaceType,
String methodName
) {
if (originalMethodContext == null) {
return new KNNMethodContext(knnEngine, spaceType, new MethodComponentContext(methodName, new HashMap<>()));
}
return new KNNMethodContext(originalMethodContext);
}

protected String getEncoderName(KNNMethodContext knnMethodContext) {
if (isEncoderSpecified(knnMethodContext) == false) {
return null;
}

MethodComponentContext methodComponentContext = getEncoderComponentContext(knnMethodContext);
if (methodComponentContext == null) {
return null;
}

return methodComponentContext.getName();
}

protected MethodComponentContext getEncoderComponentContext(KNNMethodContext knnMethodContext) {
if (isEncoderSpecified(knnMethodContext) == false) {
return null;
}

return (MethodComponentContext) knnMethodContext.getMethodComponentContext().getParameters().get(METHOD_ENCODER_PARAMETER);
}

/**
* Determine if the encoder parameter is specified
*
* @param knnMethodContext {@link KNNMethodContext}
* @return true is the encoder is specified in the structure; false otherwise
*/
protected boolean isEncoderSpecified(KNNMethodContext knnMethodContext) {
return knnMethodContext != null
&& knnMethodContext.getMethodComponentContext().getParameters() != null
&& knnMethodContext.getMethodComponentContext().getParameters().containsKey(METHOD_ENCODER_PARAMETER);
}

protected boolean shouldEncoderBeResolved(KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext) {
// The encoder should not be resolved if:
// 1. The encoder is specified
// 2. The compression is x1
// 3. The compression is not specified and the mode is not disk-based
if (isEncoderSpecified(knnMethodContext)) {
return false;
}

if (knnMethodConfigContext.getCompressionLevel() == CompressionLevel.x1) {
return false;
}

if (CompressionLevel.isConfigured(knnMethodConfigContext.getCompressionLevel()) == false
&& Mode.ON_DISK != knnMethodConfigContext.getMode()) {
return false;
}

if (VectorDataType.FLOAT != knnMethodConfigContext.getVectorDataType()) {
return false;
}

return true;
}

protected ValidationException validateNotTrainingContext(
boolean shouldRequireTraining,
KNNEngine knnEngine,
ValidationException validationException
) {
if (shouldRequireTraining) {
validationException = validationException == null ? new ValidationException() : validationException;
validationException.addValidationError(
String.format(Locale.ROOT, "Cannot use \"%s\" engine from training context", knnEngine.getName())
);
}

return validationException;
}

protected ValidationException validateCompressionSupported(
CompressionLevel compressionLevel,
Set<CompressionLevel> supportedCompressionLevels,
KNNEngine knnEngine,
ValidationException validationException
) {
if (CompressionLevel.isConfigured(compressionLevel) && supportedCompressionLevels.contains(compressionLevel) == false) {
validationException = validationException == null ? new ValidationException() : validationException;
validationException.addValidationError(
String.format(Locale.ROOT, "\"%s\" does not support \"%s\" compression", knnEngine.getName(), compressionLevel.getName())
);
}
return validationException;
}

protected ValidationException validateCompressionNotx1WhenOnDisk(
KNNMethodConfigContext knnMethodConfigContext,
ValidationException validationException
) {
if (knnMethodConfigContext.getCompressionLevel() == CompressionLevel.x1 && knnMethodConfigContext.getMode() == Mode.ON_DISK) {
validationException = validationException == null ? new ValidationException() : validationException;
validationException.addValidationError(
String.format(Locale.ROOT, "Cannot specify \"x1\" compression level when using \"%s\" mode", Mode.ON_DISK.getName())
);
}
return validationException;
}

protected void validateCompressionConflicts(CompressionLevel originalCompressionLevel, CompressionLevel resolvedCompressionLevel) {
if (CompressionLevel.isConfigured(originalCompressionLevel)
&& CompressionLevel.isConfigured(resolvedCompressionLevel)
&& resolvedCompressionLevel != originalCompressionLevel) {
ValidationException validationException = new ValidationException();
validationException.addValidationError("Cannot specify an encoder that conflicts with the provided compression level");
throw validationException;
}
}
}
12 changes: 12 additions & 0 deletions src/main/java/org/opensearch/knn/index/engine/Encoder.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.knn.index.engine;

import org.opensearch.knn.index.mapper.CompressionLevel;

/**
* Interface representing an encoder. An encoder generally refers to a vector quantizer.
*/
Expand All @@ -24,4 +26,14 @@ default String getName() {
* @return Method component associated with the encoder
*/
MethodComponent getMethodComponent();

/**
* Calculate the compression level for the give params. Assume float32 vectors are used. All parameters should
* be resolved in the encoderContext passed in.
*
* @param encoderContext Context for the encoder to extract params from
* @return Compression level this encoder produces. If the encoder does not support this calculation yet, it will
* return {@link CompressionLevel#NOT_CONFIGURED}
*/
CompressionLevel calculateCompressionLevel(MethodComponentContext encoderContext, KNNMethodConfigContext knnMethodConfigContext);
}
62 changes: 62 additions & 0 deletions src/main/java/org/opensearch/knn/index/engine/EngineResolver.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.engine;

import org.opensearch.knn.index.mapper.CompressionLevel;
import org.opensearch.knn.index.mapper.Mode;

/**
* Figures out what {@link KNNEngine} to use based on configuration details
*/
public final class EngineResolver {

public static final EngineResolver INSTANCE = new EngineResolver();

private EngineResolver() {}

/**
* Based on the provided {@link Mode} and {@link CompressionLevel}, resolve to a {@link KNNEngine}.
*
* @param knnMethodConfigContext configuration context
* @param knnMethodContext KNNMethodContext
* @param requiresTraining whether config requires training
* @return {@link KNNEngine}
*/
public KNNEngine resolveEngine(
KNNMethodConfigContext knnMethodConfigContext,
KNNMethodContext knnMethodContext,
boolean requiresTraining
) {
// User configuration gets precedence
if (knnMethodContext != null && knnMethodContext.isEngineConfigured()) {
return knnMethodContext.getKnnEngine();
}

// Faiss is the only engine that supports training, so we default to faiss here for now
if (requiresTraining) {
return KNNEngine.FAISS;
}

Mode mode = knnMethodConfigContext.getMode();
CompressionLevel compressionLevel = knnMethodConfigContext.getCompressionLevel();
// If both mode and compression are not specified, we can just default
if (Mode.isConfigured(mode) == false && CompressionLevel.isConfigured(compressionLevel) == false) {
return KNNEngine.DEFAULT;
}

// For 1x, we need to default to faiss if mode is provided and use nmslib otherwise
if (CompressionLevel.isConfigured(compressionLevel) == false || compressionLevel == CompressionLevel.x1) {
return mode == Mode.ON_DISK ? KNNEngine.FAISS : KNNEngine.DEFAULT;
}

// Lucene is only engine that supports 4x - so we have to default to it here.
if (compressionLevel == CompressionLevel.x4) {
return KNNEngine.LUCENE;
}

return KNNEngine.FAISS;
}
}
10 changes: 10 additions & 0 deletions src/main/java/org/opensearch/knn/index/engine/KNNEngine.java
Original file line number Diff line number Diff line change
Expand Up @@ -201,4 +201,14 @@ public void setInitialized(Boolean isInitialized) {
public List<String> mmapFileExtensions() {
return knnLibrary.mmapFileExtensions();
}

@Override
public ResolvedMethodContext resolveMethod(
KNNMethodContext knnMethodContext,
KNNMethodConfigContext knnMethodConfigContext,
boolean shouldRequireTraining,
final SpaceType spaceType
) {
return knnLibrary.resolveMethod(knnMethodContext, knnMethodConfigContext, shouldRequireTraining, spaceType);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
/**
* KNNLibrary is an interface that helps the plugin communicate with k-NN libraries
*/
public interface KNNLibrary {
public interface KNNLibrary extends MethodResolver {

/**
* Gets the version of the library that is being used. In general, this can be used for ensuring compatibility of
Expand Down
Loading

0 comments on commit 270ac6a

Please sign in to comment.