Skip to content

Commit

Permalink
Add model version to model metadata and change model metadata reads t…
Browse files Browse the repository at this point in the history
…o be from cluster metadata (opensearch-project#2005)

* Add model version to model metadata

Signed-off-by: Ryan Bogan <rbogan@amazon.com>

* Add model version to model metadata and change model metadata reads to be from cluster metadata

Signed-off-by: Ryan Bogan <rbogan@amazon.com>

* Add changelog entry

Signed-off-by: Ryan Bogan <rbogan@amazon.com>

* Set version from config context

Signed-off-by: Ryan Bogan <rbogan@amazon.com>

* Fix spotless

Signed-off-by: Ryan Bogan <rbogan@amazon.com>

* Update model index mappings

Signed-off-by: Ryan Bogan <rbogan@amazon.com>

* Change field mapper to read model version

Signed-off-by: Ryan Bogan <rbogan@amazon.com>

* Fix tests

Signed-off-by: Ryan Bogan <rbogan@amazon.com>

* remove println

Signed-off-by: John Mazanec <jmazane@amazon.com>

---------

Signed-off-by: Ryan Bogan <rbogan@amazon.com>
Signed-off-by: John Mazanec <jmazane@amazon.com>
Co-authored-by: John Mazanec <jmazane@amazon.com>
  • Loading branch information
ryanbogan and jmazanec15 authored Sep 5, 2024
1 parent da854c9 commit 6814c8f
Show file tree
Hide file tree
Showing 27 changed files with 484 additions and 319 deletions.
1 change: 1 addition & 0 deletions release-notes/opensearch-knn.release-notes-2.17.0.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Compatible with OpenSearch 2.17.0
* Add spaceType as a top level optional parameter while creating vector field. [#2044](https://github.com/opensearch-project/k-NN/pull/2044)
### Enhancements
* Adds iterative graph build capability into a faiss index to improve the memory footprint during indexing and Integrates KNNVectorsFormat for native engines[#1950](https://github.com/opensearch-project/k-NN/pull/1950)
* Add model version to model metadata and change model metadata reads to be from cluster metadata [#2005](https://github.com/opensearch-project/k-NN/pull/2005)
### Bug Fixes
* Corrected search logic for scenario with non-existent fields in filter [#1874](https://github.com/opensearch-project/k-NN/pull/1874)
* Add script_fields context to KNNAllowlist [#1917] (https://github.com/opensearch-project/k-NN/pull/1917)
Expand Down
1 change: 1 addition & 0 deletions src/main/java/org/opensearch/knn/common/KNNConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ public class KNNConstants {
public static final String TOP_LEVEL_SPACE_TYPE_FEATURE = "top_level_space_type_feature";

public static final String RADIAL_SEARCH_KEY = "radial_search";
public static final String MODEL_VERSION = "model_version";
public static final String QUANTIZATION_STATE_FILE_SUFFIX = "osknnqstate";

// Lucene specific constants
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ private static KNNMethodConfigContext getKNNMethodConfigContextFromModelMetadata
return KNNMethodConfigContext.builder()
.vectorDataType(modelMetadata.getVectorDataType())
.dimension(modelMetadata.getDimension())
.versionCreated(Version.V_2_14_0)
.versionCreated(modelMetadata.getModelVersion())
.mode(modelMetadata.getMode())
.compressionLevel(modelMetadata.getCompressionLevel())
.build();
Expand Down
2 changes: 2 additions & 0 deletions src/main/java/org/opensearch/knn/index/util/IndexUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ public class IndexUtil {
private static final Version MINIMAL_RESCORE_FEATURE = Version.V_2_17_0;
private static final Version MINIMAL_MODE_AND_COMPRESSION_FEATURE = Version.V_2_17_0;
private static final Version MINIMAL_TOP_LEVEL_SPACE_TYPE_FEATURE = Version.V_2_17_0;
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_VERSION = Version.V_2_17_0;
// public so neural search can access it
public static final Map<String, Version> minimalRequiredVersionMap = initializeMinimalRequiredVersionMap();
public static final Set<VectorDataType> VECTOR_DATA_TYPES_NOT_SUPPORTING_ENCODERS = Set.of(VectorDataType.BINARY, VectorDataType.BYTE);
Expand Down Expand Up @@ -392,6 +393,7 @@ private static Map<String, Version> initializeMinimalRequiredVersionMap() {
put(RESCORE_PARAMETER, MINIMAL_RESCORE_FEATURE);
put(KNNConstants.MINIMAL_MODE_AND_COMPRESSION_FEATURE, MINIMAL_MODE_AND_COMPRESSION_FEATURE);
put(KNNConstants.TOP_LEVEL_SPACE_TYPE_FEATURE, MINIMAL_TOP_LEVEL_SPACE_TYPE_FEATURE);
put(KNNConstants.MODEL_VERSION, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_VERSION);
}
};

Expand Down
1 change: 1 addition & 0 deletions src/main/java/org/opensearch/knn/indices/ModelDao.java
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,7 @@ private void putInternal(Model model, ActionListener<IndexResponse> listener, Do
if (CompressionLevel.isConfigured(modelMetadata.getCompressionLevel())) {
put(KNNConstants.COMPRESSION_LEVEL_PARAMETER, modelMetadata.getCompressionLevel().getName());
}
put(KNNConstants.MODEL_VERSION, modelMetadata.getModelVersion());
MethodComponentContext methodComponentContext = modelMetadata.getMethodComponentContext();
if (!methodComponentContext.getName().isEmpty()) {
XContentBuilder builder = XContentFactory.jsonBuilder().startObject();
Expand Down
56 changes: 46 additions & 10 deletions src/main/java/org/opensearch/knn/indices/ModelMetadata.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import lombok.extern.log4j.Log4j2;
import org.apache.commons.lang.builder.EqualsBuilder;
import org.apache.commons.lang.builder.HashCodeBuilder;
import org.opensearch.Version;
import org.opensearch.common.xcontent.json.JsonXContent;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
Expand Down Expand Up @@ -59,14 +60,14 @@ public class ModelMetadata implements Writeable, ToXContentObject {
private String error;
@Getter
private final CompressionLevel compressionLevel;
private final Version version;

/**
* Constructor
*
* @param in Stream input
*/
public ModelMetadata(StreamInput in) throws IOException {
String tempTrainingNodeAssignment;
this.knnEngine = KNNEngine.getEngine(in.readString());
this.spaceType = SpaceType.getSpace(in.readString());
this.dimension = in.readInt();
Expand Down Expand Up @@ -96,7 +97,6 @@ public ModelMetadata(StreamInput in) throws IOException {
} else {
this.vectorDataType = VectorDataType.DEFAULT;
}

if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(in.getVersion(), KNNConstants.MINIMAL_MODE_AND_COMPRESSION_FEATURE)) {
this.mode = Mode.fromName(in.readOptionalString());
this.compressionLevel = CompressionLevel.fromName(in.readOptionalString());
Expand All @@ -105,6 +105,11 @@ public ModelMetadata(StreamInput in) throws IOException {
this.compressionLevel = CompressionLevel.NOT_CONFIGURED;
}

if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(in.getVersion(), KNNConstants.MODEL_VERSION)) {
this.version = Version.fromString(in.readString());
} else {
this.version = Version.V_EMPTY;
}
}

/**
Expand Down Expand Up @@ -133,7 +138,8 @@ public ModelMetadata(
MethodComponentContext methodComponentContext,
VectorDataType vectorDataType,
Mode mode,
CompressionLevel compressionLevel
CompressionLevel compressionLevel,
Version version
) {
this.knnEngine = Objects.requireNonNull(knnEngine, "knnEngine must not be null");
this.spaceType = Objects.requireNonNull(spaceType, "spaceType must not be null");
Expand All @@ -159,6 +165,7 @@ public ModelMetadata(
this.vectorDataType = Objects.requireNonNull(vectorDataType, "vector data type must not be null");
this.mode = Objects.requireNonNull(mode, "Mode must not be null");
this.compressionLevel = Objects.requireNonNull(compressionLevel, "Compression level must not be null");
this.version = Objects.requireNonNull(version, "model version must not be null");
}

/**
Expand Down Expand Up @@ -246,6 +253,14 @@ public VectorDataType getVectorDataType() {
return vectorDataType;
}

/**
* Getter for the model version
* @return version
*/
public Version getModelVersion() {
return version;
}

/**
* setter for model's state
*
Expand Down Expand Up @@ -279,7 +294,8 @@ public String toString() {
methodComponentContext.toClusterStateString(),
vectorDataType.getValue(),
mode.getName(),
compressionLevel.getName()
compressionLevel.getName(),
version.toString()
);
}

Expand Down Expand Up @@ -317,6 +333,7 @@ public int hashCode() {
.append(getVectorDataType())
.append(getMode())
.append(getCompressionLevel())
.append(getModelVersion())
.toHashCode();
}

Expand All @@ -329,15 +346,15 @@ public int hashCode() {
public static ModelMetadata fromString(String modelMetadataString) {
String[] modelMetadataArray = modelMetadataString.split(DELIMITER, -1);
int length = modelMetadataArray.length;

if (length < 7 || length > 12) {
if (length < 7 || length > 13) {
throw new IllegalArgumentException(
"Illegal format for model metadata. Must be of the form "
+ "\"<KNNEngine>,<SpaceType>,<Dimension>,<ModelState>,<Timestamp>,<Description>,<Error>\" or "
+ "\"<KNNEngine>,<SpaceType>,<Dimension>,<ModelState>,<Timestamp>,<Description>,<Error>,<NodeAssignment>\" or "
+ "\"<KNNEngine>,<SpaceType>,<Dimension>,<ModelState>,<Timestamp>,<Description>,<Error>,<NodeAssignment>,<MethodContext>\" or "
+ "\"<KNNEngine>,<SpaceType>,<Dimension>,<ModelState>,<Timestamp>,<Description>,<Error>,<NodeAssignment>,<MethodContext>,<VectorDataType>\". or "
+ "\"<KNNEngine>,<SpaceType>,<Dimension>,<ModelState>,<Timestamp>,<Description>,<Error>,<NodeAssignment>,<MethodContext>,<VectorDataType>,<Mode>,<CompressionLevel>\"."
+ "\"<KNNEngine>,<SpaceType>,<Dimension>,<ModelState>,<Timestamp>,<Description>,<Error>,<NodeAssignment>,<MethodContext>,<VectorDataType>,<Mode>,<CompressionLevel>\" or "
+ "\"<KNNEngine>,<SpaceType>,<Dimension>,<ModelState>,<Timestamp>,<Description>,<Error>,<NodeAssignment>,<MethodContext>,<VectorDataType>,<Mode>,<CompressionLevel>,<Version>\"."
);
}

Expand All @@ -357,6 +374,7 @@ public static ModelMetadata fromString(String modelMetadataString) {
CompressionLevel compressionLevel = length > 11
? CompressionLevel.fromName(modelMetadataArray[11])
: CompressionLevel.NOT_CONFIGURED;
Version version = length > 12 ? Version.fromString(modelMetadataArray[12]) : Version.V_EMPTY;

log.debug(getLogMessage(length));

Expand All @@ -372,7 +390,8 @@ public static ModelMetadata fromString(String modelMetadataString) {
methodComponentContext,
vectorDataType,
mode,
compressionLevel
compressionLevel,
version
);
}

Expand All @@ -386,9 +405,10 @@ private static String getLogMessage(int length) {
return "Model metadata contains training node assignment and method context.";
case 10:
return "Model metadata contains training node assignment, method context and vector data type.";
case 11:
case 12:
return "Model metadata contains mode and compression level";
case 13:
return "Model metadata contains training node assignment, method context, vector data type, and version";
default:
throw new IllegalArgumentException("Unexpected metadata array length: " + length);
}
Expand Down Expand Up @@ -423,6 +443,7 @@ public static ModelMetadata getMetadataFromSourceMap(final Map<String, Object> m
Object vectorDataType = modelSourceMap.get(KNNConstants.VECTOR_DATA_TYPE_FIELD);
Object mode = modelSourceMap.get(KNNConstants.MODE_PARAMETER);
Object compressionLevel = modelSourceMap.get(KNNConstants.COMPRESSION_LEVEL_PARAMETER);
Object version = modelSourceMap.get(KNNConstants.MODEL_VERSION);

if (trainingNodeAssignment == null) {
trainingNodeAssignment = "";
Expand All @@ -447,6 +468,10 @@ public static ModelMetadata getMetadataFromSourceMap(final Map<String, Object> m
vectorDataType = VectorDataType.DEFAULT.getValue();
}

if (version == null) {
version = Version.V_EMPTY;
}

ModelMetadata modelMetadata = new ModelMetadata(
KNNEngine.getEngine(objectToString(engine)),
SpaceType.getSpace(objectToString(space)),
Expand All @@ -459,7 +484,8 @@ public static ModelMetadata getMetadataFromSourceMap(final Map<String, Object> m
(MethodComponentContext) methodComponentContext,
VectorDataType.get(objectToString(vectorDataType)),
Mode.fromName(objectToString(mode)),
CompressionLevel.fromName(objectToString(compressionLevel))
CompressionLevel.fromName(objectToString(compressionLevel)),
Version.fromString(version.toString())
);
return modelMetadata;
}
Expand All @@ -486,6 +512,9 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalString(mode.getName());
out.writeOptionalString(compressionLevel.getName());
}
if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(out.getVersion(), KNNConstants.MODEL_VERSION)) {
out.writeString(version.toString());
}
}

@Override
Expand Down Expand Up @@ -517,6 +546,13 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(KNNConstants.COMPRESSION_LEVEL_PARAMETER, compressionLevel.getName());
}
}
if (IndexUtil.isClusterOnOrAfterMinRequiredVersion(KNNConstants.MODEL_VERSION)) {
String versionString = "unknown";
if (version != Version.V_EMPTY) {
versionString = version.toString();
}
builder.field(KNNConstants.MODEL_VERSION, versionString);
}
return builder;
}
}
4 changes: 2 additions & 2 deletions src/main/java/org/opensearch/knn/indices/ModelUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ public static ModelMetadata getModelMetadata(final String modelId) {
if (StringUtils.isEmpty(modelId)) {
return null;
}
final Model model = ModelCache.getInstance().get(modelId);
final ModelMetadata modelMetadata = model.getModelMetadata();
ModelDao modelDao = ModelDao.OpenSearchKNNModelDao.getInstance();
final ModelMetadata modelMetadata = modelDao.getMetadata(modelId);
if (isModelCreated(modelMetadata) == false) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "Model ID '%s' is not created.", modelId));
}
Expand Down
3 changes: 2 additions & 1 deletion src/main/java/org/opensearch/knn/training/TrainingJob.java
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ public TrainingJob(
knnMethodContext.getMethodComponentContext(),
knnMethodConfigContext.getVectorDataType(),
mode,
compressionLevel
compressionLevel,
knnMethodConfigContext.getVersionCreated()
),
null,
this.modelId
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.common.ValidationException;
import org.opensearch.knn.indices.Model;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelState;
Expand Down Expand Up @@ -166,11 +165,11 @@ private void train(TrainingJob trainingJob) {
private void serializeModel(TrainingJob trainingJob, ActionListener<IndexResponse> listener, boolean update) throws IOException,
ExecutionException, InterruptedException {
if (update) {
Model model = modelDao.get(trainingJob.getModelId());
if (model.getModelMetadata().getState().equals(ModelState.TRAINING)) {
ModelMetadata modelMetadata = modelDao.getMetadata(trainingJob.getModelId());
if (modelMetadata.getState().equals(ModelState.TRAINING)) {
modelDao.update(trainingJob.getModel(), listener);
} else {
logger.info("Model state is {}. Skipping serialization of trained data", model.getModelMetadata().getState());
logger.info("Model state is {}. Skipping serialization of trained data", modelMetadata.getState());
}
} else {
modelDao.put(trainingJob.getModel(), listener);
Expand Down
3 changes: 3 additions & 0 deletions src/main/resources/mappings/model-index.json
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@
},
"compression_level": {
"type": "keyword"
},
"model_version": {
"type": "keyword"
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
package org.opensearch.knn.index;

import com.google.common.collect.ImmutableMap;
import org.opensearch.Version;
import org.opensearch.core.action.ActionListener;
import org.opensearch.action.admin.indices.create.CreateIndexRequestBuilder;
import org.opensearch.common.settings.Settings;
Expand Down Expand Up @@ -69,7 +70,8 @@ public void testCreateIndexFromModel() throws IOException, InterruptedException
MethodComponentContext.EMPTY,
VectorDataType.FLOAT,
Mode.NOT_CONFIGURED,
CompressionLevel.NOT_CONFIGURED
CompressionLevel.NOT_CONFIGURED,
Version.V_EMPTY
);

Model model = new Model(modelMetadata, modelBlob, modelId);
Expand Down
Loading

0 comments on commit 6814c8f

Please sign in to comment.